From e487886d880cc9480b3cf819cd00d9c6078eb12e Mon Sep 17 00:00:00 2001 From: alanseed Date: Tue, 3 Jun 2025 20:30:13 +1000 Subject: [PATCH 01/12] Initial commit --- pysteps/mongo/README.md | 42 ++ pysteps/mongo/create_mongo_user.py | 44 ++ pysteps/mongo/delete_files.py | 125 +++++ pysteps/mongo/gridfs_io.py | 266 +++++++++++ pysteps/mongo/init_steps_db.py | 71 +++ pysteps/mongo/load_config.py | 283 ++++++++++++ pysteps/mongo/mongo_access.py | 184 ++++++++ pysteps/mongo/mongodb_port_forwarding.md | 128 ++++++ pysteps/mongo/nc_utils.py | 318 +++++++++++++ pysteps/mongo/pysteps_config.json | 133 ++++++ pysteps/mongo/write_ensemble.py | 335 ++++++++++++++ pysteps/mongo/write_nc_files.py | 307 +++++++++++++ pysteps/param/README.md | 45 ++ pysteps/param/broken_line.py | 124 +++++ pysteps/param/calibrate_ar_model.py | 321 +++++++++++++ pysteps/param/cascade_utils.py | 39 ++ pysteps/param/make_cascades.py | 281 ++++++++++++ pysteps/param/make_parameters.py | 413 +++++++++++++++++ pysteps/param/nwp_param_qc.py | 199 ++++++++ pysteps/param/pysteps_param.py | 369 +++++++++++++++ pysteps/param/shared_utils.py | 181 ++++++++ pysteps/param/steps_params.py | 555 +++++++++++++++++++++++ pysteps/param/stochastic_generator.py | 195 ++++++++ 23 files changed, 4958 insertions(+) create mode 100644 pysteps/mongo/README.md create mode 100644 pysteps/mongo/create_mongo_user.py create mode 100644 pysteps/mongo/delete_files.py create mode 100644 pysteps/mongo/gridfs_io.py create mode 100644 pysteps/mongo/init_steps_db.py create mode 100644 pysteps/mongo/load_config.py create mode 100644 pysteps/mongo/mongo_access.py create mode 100644 pysteps/mongo/mongodb_port_forwarding.md create mode 100644 pysteps/mongo/nc_utils.py create mode 100644 pysteps/mongo/pysteps_config.json create mode 100644 pysteps/mongo/write_ensemble.py create mode 100644 pysteps/mongo/write_nc_files.py create mode 100644 pysteps/param/README.md create mode 100644 pysteps/param/broken_line.py create mode 100644 pysteps/param/calibrate_ar_model.py create mode 100644 pysteps/param/cascade_utils.py create mode 100644 pysteps/param/make_cascades.py create mode 100644 pysteps/param/make_parameters.py create mode 100644 pysteps/param/nwp_param_qc.py create mode 100644 pysteps/param/pysteps_param.py create mode 100644 pysteps/param/shared_utils.py create mode 100644 pysteps/param/steps_params.py create mode 100644 pysteps/param/stochastic_generator.py diff --git a/pysteps/mongo/README.md b/pysteps/mongo/README.md new file mode 100644 index 000000000..1c97b162a --- /dev/null +++ b/pysteps/mongo/README.md @@ -0,0 +1,42 @@ +# mongo + +## Executable scripts + +### create_mongo_user.py + +This script is run by the database administrator to register a new user for the STEPS database. + +### delete_files.py + +Housekeeping utility to delete rainfall, cascade, and parameter records from the database. + +### init_steps_db.py + +This script creates the STEPS database with the expected collections and indices. + +### load_config.py + +This script loads the JSON configuration file into the STEPS database. + +### write_nc_files.py + +Reads the database and generates the netCDF files for export to users. + +### write_ensemble.py + +An example of an end-user product: writes the past QPE and forecast ensemble rain fields to a single netCDF file. + +## Modules + +### gridfs_io.py + +Functions to read and write the binary rain and cascade-state data to GridFS. + +### mongo_access.py + +Functions to read and write the metadata and parameters. + +### nc_utils.py + +Functions to read and write the rain fields as CF-compliant netCDF binaries.
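A minimal usage sketch for the I/O modules (hypothetical domain name, product, and time window; the scripts in this patch import these helpers from a `models` package, so adjust the import path to your own layout):

```python
import datetime

from models.mongo_access import get_db, get_config
from models.gridfs_io import get_rain_fields

db = get_db()                   # host/credentials come from MONGO_HOST, MONGO_PORT, STEPS_USER, STEPS_PWD
config = get_config(db, "AKL")  # most recent configuration document for the AKL domain

# Hypothetical one-hour window of radar QPE fields
end = datetime.datetime(2025, 6, 3, 12, 0, tzinfo=datetime.timezone.utc)
start = end - datetime.timedelta(hours=1)
query = {
    "metadata.product": "QPE",
    "metadata.valid_time": {"$gte": start, "$lte": end},
}
fields = get_rain_fields(db, "AKL", query)  # list of {"rain": ndarray, "metadata": {...}} records
```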
+ diff --git a/pysteps/mongo/create_mongo_user.py b/pysteps/mongo/create_mongo_user.py new file mode 100644 index 000000000..8f5bce982 --- /dev/null +++ b/pysteps/mongo/create_mongo_user.py @@ -0,0 +1,44 @@ +import secrets +import string +import os +from pymongo import MongoClient +from urllib.parse import quote_plus + +# === CONFIGURATION === +MONGO_HOST = "localhost" +MONGO_PORT = 27017 +AUTH_DB = "admin" +MONGO_ADMIN_USER = os.getenv("MONGO_USER") +MONGO_ADMIN_PASS = os.getenv("MONGO_PWD") +TARGET_DB = "STEPS" +PWD_DEFAULT = "c-bandBox" + +# === FUNCTIONS === +def generate_password(length=16): + alphabet = string.ascii_letters + string.digits + "!@#$%^&*()-_=+" + return ''.join(secrets.choice(alphabet) for _ in range(length)) + +def create_user(username, role="readWrite"): + # password = generate_password() + password = PWD_DEFAULT + client = MongoClient(f"mongodb://{quote_plus(MONGO_ADMIN_USER)}:{quote_plus(MONGO_ADMIN_PASS)}@{MONGO_HOST}:{MONGO_PORT}/?authSource={AUTH_DB}") + db = client[TARGET_DB] + + try: + db.command("createUser", username, pwd=password, roles=[{"role": role, "db": TARGET_DB}]) + print(f"\n✅ User '{username}' created with role '{role}'.\n") + print("Connection string:") + print(f" mongodb://{quote_plus(username)}:{quote_plus(password)}@{MONGO_HOST}:{MONGO_PORT}/{TARGET_DB}?authSource={TARGET_DB}\n") + except Exception as e: + print(f"❌ Failed to create user '{username}': {e}") + +# === ENTRY POINT === +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Create a MongoDB user with a random password.") + parser.add_argument("username", help="Username to create") + parser.add_argument("--role", default="readWrite", help="MongoDB role (default: readWrite)") + args = parser.parse_args() + create_user(args.username, args.role) + diff --git a/pysteps/mongo/delete_files.py b/pysteps/mongo/delete_files.py new file mode 100644 index 000000000..71279662d --- /dev/null +++ b/pysteps/mongo/delete_files.py @@ -0,0 +1,125 @@ +from models import get_db +from pymongo import MongoClient +import logging +import argparse +import gridfs +import pymongo +import datetime + +def is_valid_iso8601(time_str: str) -> bool: + """Check if the given string is a valid ISO 8601 datetime.""" + try: + datetime.datetime.fromisoformat(time_str) + return True + except ValueError: + return False + + +def main(): + parser = argparse.ArgumentParser( + description="Delete rainfall and/or state GridFS files.") + + parser.add_argument('-s', '--start', type=str, required=True, + help='Start time yyyy-mm-ddTHH:MM:SS') + parser.add_argument('-e', '--end', type=str, required=True, + help='End time yyyy-mm-ddTHH:MM:SS') + parser.add_argument('-n', '--name', type=str, required=True, + help='Name of domain [AKL]') + parser.add_argument('-p', '--product', type=str, required=True, + help='Name of product to delete [QPE, auckprec, qpesim]') + parser.add_argument('-c', '--cascade', default=False, action='store_true', + help='Delete the cascade files') + parser.add_argument('-r', '--rain', default=False, action='store_true', + help='Delete the rainfall files') + parser.add_argument('--params', default=False, action='store_true', + help='Delete the parameter documents') + + parser.add_argument('--dry_run', default=False, action='store_true', + help='Only list files that would be deleted, don’t delete them.') + + args = parser.parse_args() + logging.basicConfig(level=logging.INFO) + + if not (args.rain or args.cascade or args.params): + logging.warning("Nothing to delete: specify 
--rain, --cascade, or --params") + return + + # Validate and parse times + def parse_time(time_str): + if not is_valid_iso8601(time_str): + logging.error(f"Invalid time format: {time_str}") + exit(1) + t = datetime.datetime.fromisoformat(time_str) + return t.replace(tzinfo=datetime.timezone.utc) if t.tzinfo is None else t + + start_time = parse_time(args.start) + end_time = parse_time(args.end) + + name = args.name + product = args.product + dry_run = args.dry_run + + if product not in ["QPE", "auckprec", "qpesim", "nwpblend"]: + logging.error(f"Invalid product: {product}") + return + + db = get_db() + + def delete_files(collection_name): + coll = db[f"{collection_name}.files"] + fs = gridfs.GridFS(db, collection=collection_name) + + if product == "QPE": + query = { + "metadata.product": product, + "metadata.valid_time": {"$gte": start_time, "$lte": end_time} + } + else: + query = { + "metadata.product": product, + "metadata.base_time": {"$gte": start_time, "$lte": end_time} + } + + ids = list(coll.find(query, {"_id": 1,"filename":1})) + count = len(ids) + + if dry_run: + logging.info(f"[Dry Run] {count} files matched in {collection_name}. Listing _id values:") + for doc in ids: + logging.info(f" Would delete: {doc['filename']}") + else: + for doc in ids: + fs.delete(doc["_id"]) + logging.info(f"Deleted {count} files from {collection_name}") + + if args.rain: + delete_files(f"{name}.rain") + + if args.cascade: + delete_files(f"{name}.state") + + if args.params: + collection_name = f"{name}.params" + coll = db[collection_name] + if product == "QPE": + query = { + "metadata.product": product, + "metadata.valid_time": {"$gte": start_time, "$lte": end_time} + } + else: + query = { + "metadata.product": product, + "metadata.base_time": {"$gte": start_time, "$lte": end_time} + } + + ids = list(coll.find(query, {"_id": 1})) + count = len(ids) + if dry_run: + logging.info(f"[Dry Run] {count} files matched in {collection_name}") + else: + coll.delete_many(query) + logging.info(f"Deleted {count} files from {collection_name}") + + +if __name__ == "__main__": + main() diff --git a/pysteps/mongo/gridfs_io.py b/pysteps/mongo/gridfs_io.py new file mode 100644 index 000000000..dc5cf68da --- /dev/null +++ b/pysteps/mongo/gridfs_io.py @@ -0,0 +1,266 @@ +# Contains: store_cascade_to_gridfs, load_cascade_from_gridfs, load_rain_field, get_rain_fields, get_states +from io import BytesIO +import gridfs +import numpy as np +import pymongo +import copy +import datetime +from typing import Dict, Any, Optional, Union, Tuple + +def store_cascade_to_gridfs(db, name, cascade_dict, oflow, file_name, field_metadata): + """ + Stores a pysteps cascade decomposition dictionary into MongoDB's GridFS. + + Parameters: + db (pymongo.database.Database): The MongoDB database object. + cascade_dict (dict): The pysteps cascade decomposition dictionary. + oflow (np.ndarray): The optical flow field. + file_name (str): The (unique) name of the file to be stored. + field_metadata (dict): Additional metadata related to the field. + + Returns: + bson.ObjectId: The GridFS file ID. + """ + assert cascade_dict["domain"] == "spatial", "Only 'spatial' domain is supported." 
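    # Storage layout, as implemented below: cascade_levels and the optical flow
    # are packed together into a single compressed .npz blob, while the scalars
    # needed to rebuild the cascade dictionary (domain, normalized, transform,
    # threshold, zerovalue and, when present, means/stds) travel in the GridFS
    # file metadata. Any earlier file with the same filename is deleted first,
    # so the filename acts as a unique key within the "<name>.state" bucket.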
+ state_col_name = f"{name}.state" + fs = gridfs.GridFS(db, collection=state_col_name) + + # Delete existing file with same filename + for old_file in fs.find({"filename": file_name}): + fs.delete(old_file._id) + + # Convert cascade_levels and oflow to a compressed format + buffer = BytesIO() + np.savez_compressed( + buffer, cascade_levels=cascade_dict["cascade_levels"], oflow=oflow) + buffer.seek(0) + + # Prepare metadata + metadata = { + "filename": file_name, + "domain": cascade_dict["domain"], + "normalized": cascade_dict["normalized"], + "transform": cascade_dict.get("transform"), + "threshold": cascade_dict.get("threshold"), + "zerovalue": cascade_dict.get("zerovalue") + } + metadata.update(field_metadata) # Merge additional metadata + + # Add optional statistics if available + if "means" in cascade_dict: + metadata["means"] = cascade_dict["means"] + if "stds" in cascade_dict: + metadata["stds"] = cascade_dict["stds"] + + # Store binary data and metadata atomically in GridFS + file_id = fs.put(buffer.getvalue(), filename=file_name, metadata=metadata) + + return file_id + + +def load_cascade_from_gridfs(db, name, file_name): + """ + Loads a pysteps cascade decomposition dictionary and optical flow from MongoDB's GridFS. + + Parameters: + db (pymongo.database.Database): The MongoDB database object. + file_name (str): The name of the file to retrieve. + + Returns: + tuple: (cascade_dict, oflow, metadata) + """ + state_col_name = f"{name}.state" + fs = gridfs.GridFS(db, collection=state_col_name) + + # Retrieve the file from GridFS + grid_out = fs.find_one({"filename": file_name}) + if grid_out is None: + raise ValueError(f"No file found with filename: {file_name}") + + # Retrieve metadata + metadata = grid_out.metadata + + # Read and decompress stored arrays + buffer = BytesIO(grid_out.read()) + npzfile = np.load(buffer) + + # Reconstruct cascade dictionary including the initial field transformation + cascade_dict = { + "cascade_levels": npzfile["cascade_levels"], + "domain": metadata["domain"], + "normalized": metadata["normalized"], + "transform": metadata.get("transform"), + "threshold": metadata.get("threshold"), + "zerovalue": metadata.get("zerovalue") + } + + # Restore optional statistics if they exist + if "means" in metadata: + cascade_dict["means"] = metadata["means"] + if "stds" in metadata: + cascade_dict["stds"] = metadata["stds"] + + oflow = npzfile["oflow"] # Optical flow field + + return cascade_dict, oflow, metadata + + +def load_rain_field(db, name, filename, nc_buf, metadata): + + # Check if the file exists, if yes then delete it + rain_col_name = f"{name}.rain" + + fs = gridfs.GridFS(db, collection=rain_col_name) + + existing_file = fs.find_one( + {"filename": filename}) + if existing_file: + fs.delete(existing_file._id) + + # Upload to GridFS + fs.put(nc_buf.tobytes(), + filename=filename, metadata=metadata) + + +def get_rain_fields(db: pymongo.MongoClient, name: str, query: dict): + rain_col_name = f"{name}.rain" + meta_col_name = f"{name}.rain.files" + fs = gridfs.GridFS(db, collection=rain_col_name) + meta_coll = db[meta_col_name] + + # Fetch matching filenames and metadata in a single query + fields_projection = {"_id": 0, "filename": 1, "metadata": 1} + results = meta_coll.find(query, projection=fields_projection).sort( + "filename", pymongo.ASCENDING) + + fields = [] + + # Process each matching file + for doc in results: + filename = doc["filename"] + + # Fetch metadata from GridFS + grid_out = fs.find_one({"filename": filename}) + if grid_out is None: + 
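            # A *.files document whose binary blob is missing from GridFS is
            # logged and skipped rather than treated as a fatal error.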
logging.warning(f"File {filename} not found in GridFS, skipping.") + continue + + rain_fs_metadata = grid_out.metadata if hasattr( + grid_out, "metadata") else {} + + # Copy relevant metadata + field_metadata = { + "filename": filename, + "product": rain_fs_metadata.get("product", "unknown"), + "domain": rain_fs_metadata.get("domain", "AKL"), + "ensemble": rain_fs_metadata.get("ensemble", None), + "base_time": rain_fs_metadata.get("base_time", None), + "valid_time": rain_fs_metadata.get("valid_time", None), + "mean": rain_fs_metadata.get("mean", 0), + "std_dev": rain_fs_metadata.get("std_dev", 0), + "wetted_area_ratio": rain_fs_metadata.get("wetted_area_ratio", 0) + } + + # Stream and decompress data + buffer = BytesIO(grid_out.read()) + rain_geodata, _, rain_data = read_nc(buffer) # Fixed variable name + + # Add the georeferencing metadata dictionary + field_metadata["geo_data"] = rain_geodata + + # Store the final record + record = {"rain": rain_data.copy( + ), "metadata": copy.deepcopy(field_metadata)} + fields.append(record) # Append the record to the list + + return fields + + +def get_states(db: pymongo.MongoClient, name: str, query: dict, + get_cascade: Optional[bool] = True, + get_optical_flow: Optional[bool] = True + ) -> Dict[Tuple[Any, Any, Any], Dict[str, Optional[Union[dict, np.ndarray]]]]: + """ + Retrieve state fields (cascade and/or optical flow) from a GridFS collection, + indexed by (valid_time, base_time, ensemble). + + Args: + db (pymongo.MongoClient): Database with the state collections. + name (str): Name prefix of the state collections. + query (dict): Mongo query for filtering state files. + get_cascade (bool, optional): Whether to retrieve cascade state. Defaults to True. + get_optical_flow (bool, optional): Whether to retrieve optical flow. Defaults to True. 
+ + Returns: + dict: {(valid_time, base_time, ensemble): {"cascade": dict or None, + "optical_flow": np.ndarray or None, + "metadata": dict}} + """ + state_col_name = f"{name}.state" + meta_col_name = f"{name}.state.files" + fs = gridfs.GridFS(db, collection=state_col_name) + meta_coll = db[meta_col_name] + + fields = {"_id": 0, "filename": 1, "metadata": 1} + results = meta_coll.find(query, projection=fields).sort("filename", pymongo.ASCENDING) + + states = {} + + for doc in results: + state_file = doc["filename"] + metadata_dict = doc.get("metadata", {}) + + valid_time = metadata_dict.get("valid_time") + if valid_time is None: + logging.warning(f"No valid_time in metadata for file {state_file}, skipping.") + continue + if valid_time.tzinfo is None: + valid_time = valid_time.replace(tzinfo=datetime.timezone.utc) + + base_time = metadata_dict.get("base_time", "NA") + if base_time is not None and base_time.tzinfo is None: + base_time = base_time.replace(tzinfo=datetime.timezone.utc) + + ensemble = metadata_dict.get("ensemble", "NA") + + # Set missing base_time or ensemble to "NA" + if base_time is None: + base_time = "NA" + if ensemble is None: + ensemble = "NA" + + + grid_out = fs.find_one({"filename": state_file}) + if grid_out is None: + logging.warning(f"File {state_file} not found in GridFS, skipping.") + continue + + buffer = BytesIO(grid_out.read()) + npzfile = np.load(buffer) + + cascade_dict = None + if get_cascade: + cascade_dict = { + "cascade_levels": npzfile["cascade_levels"], + "domain": metadata_dict.get("domain"), + "normalized": metadata_dict.get("normalized"), + "transform": metadata_dict.get("transform"), + "threshold": metadata_dict.get("threshold"), + "zerovalue": metadata_dict.get("zerovalue"), + "means": metadata_dict.get("means"), + "stds": metadata_dict.get("stds"), + } + + oflow = None + if get_optical_flow: + oflow = npzfile["oflow"] + + key = (valid_time, base_time, ensemble) + states[key] = { + "cascade": copy.deepcopy(cascade_dict) if cascade_dict is not None else None, + "optical_flow": oflow.copy() if oflow is not None else None, + "metadata": copy.deepcopy(metadata_dict) + } + + return states + diff --git a/pysteps/mongo/init_steps_db.py b/pysteps/mongo/init_steps_db.py new file mode 100644 index 000000000..dba03afd9 --- /dev/null +++ b/pysteps/mongo/init_steps_db.py @@ -0,0 +1,71 @@ +from pymongo import MongoClient, ASCENDING +import argparse +import os +from pymongo import MongoClient +from urllib.parse import quote_plus + +# === Configuration === +AUTH_DB = "STEPS" +TARGET_DB = "STEPS" +MONGO_HOST = os.getenv("MONGO_HOST","localhost") +MONGO_PORT = os.getenv("MONGO_PORT",27017) +STEPS_USER = os.getenv("STEPS_USER","radar") +STEPS_PWD = os.getenv("STEPS_PWD","c-bandBox") + +# === Functions === +def setup_domain(db, domain_name): + print(f"⏳ Setting up domain: {domain_name}") + + for product in ["rain", "state"]: + files_coll = f"{domain_name}.{product}.files" + chunks_coll = f"{domain_name}.{product}.chunks" + + # Create empty collections (MongoDB creates on first insert, but we want indexes now) + db[files_coll].insert_one({"temp": True}) # insert dummy + db[chunks_coll].insert_one({"temp": True}) + + # Create compound index on files + db[files_coll].create_index([ + ("metadata.product", ASCENDING), + ("metadata.valid_time", ASCENDING), + ("metadata.base_time", ASCENDING), + ("metadata.ensemble", ASCENDING) + ], name="product_valid_base_ensemble_idx") + + # Index for GridFS pre-deletion lookups + db[files_coll].create_index([("filename", ASCENDING)], 
name="filename_idx") + + # Remove dummy record + db[files_coll].delete_many({"temp": True}) + db[chunks_coll].delete_many({"temp": True}) + + print(f"✅ {files_coll} and {chunks_coll} initialized with index") + + # Create a per-domain params collection + params_coll = f"{domain_name}.params" + db[params_coll].insert_one({"_test": True}) + db[params_coll].delete_many({"_test": True}) + print(f"✅ {params_coll} initialized") + +def setup_config(db): + config_coll = "config" + db[config_coll].insert_one({"_test": True}) + db[config_coll].delete_many({"_test": True}) + print(f"✅ {config_coll} initialized (shared)") + +# === Main === +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Initialize STEPS MongoDB structure") + parser.add_argument("domains", nargs="+", help="List of domain names to set up (e.g. AKL WLG CHC)") + args = parser.parse_args() + connect_string = f"mongodb://{quote_plus(STEPS_USER)}:{quote_plus(STEPS_PWD)}@{MONGO_HOST}:{MONGO_PORT}/STEPS?authSource={AUTH_DB}" + print(f"Connecting to {connect_string}") + client = MongoClient(connect_string) + db = client[TARGET_DB] + + for domain in args.domains: + setup_domain(db, domain) + + setup_config(db) + print("🎉 Setup complete.") + diff --git a/pysteps/mongo/load_config.py b/pysteps/mongo/load_config.py new file mode 100644 index 000000000..f20a92a47 --- /dev/null +++ b/pysteps/mongo/load_config.py @@ -0,0 +1,283 @@ +import argparse +import json +import logging +import os +from pathlib import Path +import datetime +from pymongo import MongoClient, errors +from urllib.parse import quote_plus +from models import get_db + +# Default pysteps configuration values +DEFAULT_PYSTEPS_CONFIG = { + "precip_threshold": None, + "extrapolation_method": "semilagrangian", + "decomposition_method": "fft", + "bandpass_filter_method": "gaussian", + "noise_method": "nonparametric", + "noise_stddev_adj": None, + "ar_order": 1, + "scale_break": None, + "velocity_perturbation_method": None, + "conditional": False, + "probmatching_method": "cdf", + "mask_method": "incremental", + "seed": None, + "num_workers": 1, + "fft_method": "numpy", + "domain": "spatial", + "extrapolation_kwargs": {}, + "filter_kwargs": {}, + "noise_kwargs": {}, + "velocity_perturbation_kwargs": {}, + "mask_kwargs": {}, + "measure_time": False, + "callback": None, + "return_output": True +} + +valid_product_list = ["qpesim", "auckprec", "nowcast", "nwpblend"] + +# Default output configuration +DEFAULT_OUTPUT_CONFIG = { + "qpesim":{ + "gridfs_out": True, + "nc_out": False, + "out_product": "qpesim", + "out_dir_name": None, + "out_file_name": "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc" + }, + "auckprec":{ + "gridfs_out": True, + "nc_out": False, + "out_product": "auckprec", + "tmp_dir": "$HOME/tmp", + "out_dir_name": None, + "out_file_name": "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc" + }, + "nowcast":{ + "gridfs_out": False, + "nc_out": False, + "out_product": "nowcast", + "out_dir_name": None, + "out_file_name": "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc" + }, + "nwpblend":{ + "gridfs_out": True, + "nc_out": False, + "out_product": "nwpblend", + "out_dir_name": None, + "out_file_name": "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc" + } +} + +# Default domain configuration +DEFAULT_DOMAIN_CONFIG = { + "n_rows": None, + "n_cols": None, + "p_size": None, + "start_x": None, + "start_y": None +} + +# Default projection configuration for NZ +DEFAULT_PROJECTION_CONFIG = { + "epsg": "EPSG:2193", + "name": 
"transverse_mercator", + "central_meridian": 173.0, + "latitude_of_origin": 0.0, + "scale_factor": 0.9996, + "false_easting": 1600000.0, + "false_northing": 10000000.0 +} + + +def file_exists(file_path: Path) -> bool: + """Check if the given file path exists.""" + return file_path.is_file() + + +def load_config(config_path: Path) -> dict: + """Load the full configuration from a JSON file, applying defaults for missing fields.""" + try: + with open(config_path, encoding="utf-8") as f: + config = json.load(f) + except json.JSONDecodeError: + logging.error(f"Error decoding JSON file: {config_path}") + return {} + + name = config.get("name", None) + if name is None: + logging.error( + f"Domain name not found") + return {} + + # Extract and validate pysteps configuration + pysteps_config = config.get("pysteps", {}) + if not isinstance(pysteps_config, dict): + logging.error( + f"Malformed pysteps configuration in {config_path}, expected a dictionary.") + return {} + + # Apply default values + for key, default_value in DEFAULT_PYSTEPS_CONFIG.items(): + if key not in pysteps_config: + logging.warning( + f"Missing key '{key}' in pysteps configuration, using default value.") + pysteps_config[key] = default_value + + # Validate mandatory keys + required_pysteps_keys = [ + "n_cascade_levels", "timestep", "kmperpixel" + ] + for key in required_pysteps_keys: + if key not in pysteps_config: + logging.error( + f"Missing mandatory key '{key}' in pysteps configuration.") + return {} + + # Extract and validate output configurations + output_config = config.get("output", {}) + if not isinstance(output_config, dict): + logging.error( + f"Malformed output configuration in {config_path}, expected a dictionary." + ) + return {} + + # Ensure "products" key exists and is a list + valid_product_list = ["qpesim", "auckprec", "nowcast", "nwpblend"] + + products = output_config.get("products", []) + if not isinstance(products, list): + logging.error( + f"Malformed 'products' key in output configuration, expected a list." + ) + return {} + + # Dictionary to store parsed output configurations + parsed_output_config = {} + + # Iterate over each product and extract its configuration + for product in products: + + if product not in valid_product_list: + logging.error( + f"Unexpected product found, '{product}' not in {valid_product_list}." + ) + continue + + product_config = output_config.get(product, {}) + + if not isinstance(product_config, dict): + logging.error( + f"Malformed configuration for product '{product}', expected a dictionary." 
+ ) + continue + + # Merge with defaults + complete_config = DEFAULT_OUTPUT_CONFIG[product].copy() + complete_config.update(product_config) + + parsed_output_config[product] = complete_config + + # Extract and validate the domain location configuration + domain_config = config.get("domain", {}) + if not isinstance(domain_config, dict): + logging.error( + f"Malformed domain configuration in {config_path}, expected a dictionary.") + return {} + + for key, default_value in DEFAULT_DOMAIN_CONFIG.items(): + if key not in domain_config: + logging.error(f"Missing key '{key}' in domain configuration.") + return {} + + # Extract and validate the projection configuration - assumes CF fields for Transverse Mercator + projection_config = config.get("projection", {}) + if not isinstance(projection_config, dict): + logging.error( + f"Malformed projection configuration in {config_path}, expected a dictionary.") + return {} + + for key, default_value in DEFAULT_PROJECTION_CONFIG.items(): + if key not in projection_config: + logging.warning( + f"Missing key '{key}' in projection configuration, using default value") + projection_config[key] = default_value + + # Get the dynamic scaling if present + dynamic_scaling_config = config.get("dynamic_scaling", {}) + + # Only check for required keys if the dictionary is not empty + if dynamic_scaling_config: + required_ds_keys = ["central_wave_lengths", + "space_time_exponent", "lag2_constants", "lag2_exponents"] + for key in required_ds_keys: + if key not in dynamic_scaling_config: + logging.error( + f"Missing mandatory key '{key}' in dynamic_scaling configuration.") + return {} + + return { + "name": name, + "pysteps": pysteps_config, + "output": parsed_output_config, + "domain": domain_config, + "projection": projection_config, + "dynamic_scaling": dynamic_scaling_config + } + + +def insert_config_into_mongodb(config: dict): + """Insert the configuration into the MongoDB config collection.""" + record = { + "time": datetime.datetime.now(datetime.timezone.utc), + "config": config + } + + try: + db = get_db() + collection = db["config"] + + # Insert the record + result = collection.insert_one(record) + logging.info( + f"Configuration inserted successfully. Document ID: {result.inserted_id}") + + except errors.ServerSelectionTimeoutError: + logging.error( + "Failed to connect to MongoDB. 
Check if MongoDB is running and the URI is correct.") + except errors.PyMongoError as e: + logging.error(f"MongoDB error: {e}") + + +def main(): + parser = argparse.ArgumentParser( + description="Insert pysteps configuration into MongoDB" + ) + parser.add_argument('-c', '--config', type=Path, + help='Path to configuration file') + parser.add_argument('-v', '--verbose', action='store_true', + help='Enable verbose logging') + + args = parser.parse_args() + + # Configure logging + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + + # Validate config file path + if not args.config or not file_exists(args.config): + logging.error(f"Configuration file does not exist: {args.config}") + return + + # Load the full configuration + config = load_config(args.config) + + if config: + logging.info("Final loaded configuration:\n%s", + json.dumps(config, indent=2)) + insert_config_into_mongodb(config) + + +if __name__ == "__main__": + main() diff --git a/pysteps/mongo/mongo_access.py b/pysteps/mongo/mongo_access.py new file mode 100644 index 000000000..9382a156d --- /dev/null +++ b/pysteps/mongo/mongo_access.py @@ -0,0 +1,184 @@ +# Contains: get_db, get_config, get_parameters_df, get_parameters, to_utc_naive +from typing import Dict +import pandas as pd +import datetime +import os +import logging +import pymongo.collection +from pymongo import MongoClient +from urllib.parse import quote_plus +from models.steps_params import StochasticRainParameters +from models.cascade_utils import get_cascade_wavelengths + + +def get_parameters(query: Dict, param_coll) -> Dict: + """ + Get the parameters matching the query, indexed by valid_time. + + Args: + query (dict): MongoDB query dictionary. + param_coll (pymongo collection): Collection with the parameters. + + Returns: + dict: Dictionary {valid_time: StochasticRainParameters} + """ + result = {} + for doc in param_coll.find(query).sort("metadata.valid_time", pymongo.ASCENDING): + try: + param = StochasticRainParameters.from_dict(doc) + param.calc_corl() + result[param.valid_time] = param + except Exception as e: + print( + f"Warning: could not parse parameter for valid_time {doc.get('valid_time')}: {e}") + return result + + +def get_parameters_df(query: Dict, param_coll: pymongo.collection.Collection) -> pd.DataFrame: + """ + Retrieve STEPS parameters from the database and return a DataFrame + indexed by (valid_time, base_time, ensemble), using 'NA' as sentinel for missing values. + + Args: + query (dict): MongoDB query dictionary. + param_coll (pymongo.collection.Collection): MongoDB collection. + + Returns: + pd.DataFrame: Indexed by (valid_time, base_time, ensemble), with a 'param' column. 
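    Example (hypothetical domain and product; base_time is a timezone-aware
    datetime and db is the handle returned by get_db()):

        query = {"metadata.product": "nwpblend", "metadata.base_time": base_time}
        df = get_parameters_df(query, db["AKL.params"])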
+ """ + records = [] + + for doc in param_coll.find(query).sort("metadata.valid_time", pymongo.ASCENDING): + try: + metadata = doc.get("metadata", {}) + if metadata is None: + continue + + if doc["cascade"]["lag1"] is None or doc["cascade"]["lag2"] is None: + continue + + valid_time = metadata.get("valid_time") + if valid_time is not None and valid_time.tzinfo is None: + valid_time = valid_time.replace(tzinfo=datetime.timezone.utc) + + base_time = metadata.get("base_time") + if base_time is None: + base_time = "NA" + elif base_time.tzinfo is None: + base_time = base_time.replace(tzinfo=datetime.timezone.utc) + + ensemble = metadata.get("ensemble") if metadata.get( + "ensemble") is not None else "NA" + param = StochasticRainParameters.from_dict(doc) + + param.calc_corl() + records.append({ + "valid_time": valid_time, + "base_time": base_time, + "ensemble": ensemble, + "param": param + }) + except Exception as e: + print( + f"Warning: could not parse parameter for {metadata.get('valid_time')}: {e}") + + if not records: + return pd.DataFrame(columns=["valid_time", "base_time", "ensemble", "param"]) + + df = pd.DataFrame(records) + return df + + +def get_config(db: pymongo.MongoClient, name: str) -> Dict: + """_summary_ + Return the most recent configuration setting + Args: + db (pymongo.MongoClient): Project database + + Returns: + Dict: Project configuration dictionary + """ + + config_coll = db["config"] + record = config_coll.find_one({'config.name': name}, sort=[ + ('time', pymongo.DESCENDING)]) + if record is None: + logging.error(f"Could not find configuration for domain {name}") + return None + + config = record['config'] + return config + + +def get_db(mongo_port=None): + MONGO_HOST = os.getenv("MONGO_HOST", "localhost") + # Use the function argument if provided, otherwise fall back to the environment variable, then default + MONGO_PORT = mongo_port if mongo_port is not None else int( + os.getenv("MONGO_PORT", 27017)) + + if mongo_port is None: + logging.info(f"Using MONGO_PORT from env: {MONGO_PORT}") + else: + logging.info(f"Using MONGO_PORT from argument: {mongo_port}") + + STEPS_USER = os.getenv("STEPS_USER", "radar") + STEPS_PWD = os.getenv("STEPS_PWD", "c-bandBox") + AUTH_DB = "STEPS" + TARGET_DB = "STEPS" + + conect_string = ( + f"mongodb://{quote_plus(STEPS_USER)}:{quote_plus(STEPS_PWD)}" + f"@{MONGO_HOST}:{MONGO_PORT}/STEPS?authSource={AUTH_DB}" + ) + logging.info(f"Connecting to {conect_string}") + client = MongoClient(conect_string) + db = client[TARGET_DB] + return db + + +def to_utc_naive(dt): + if dt.tzinfo is not None: + return dt.astimezone(datetime.timezone.utc).replace(tzinfo=None) + return dt + + +def get_central_wavelengths(db, name): + config = get_config(db, name) + n_levels = config["pysteps"].get("n_cascade_levels") + domain = config["domain"] + n_rows = domain.get("n_rows") + n_cols = domain.get("n_cols") + p_size = domain.get("p_size") + p_size_km = p_size / 1000.0 + domain_size_km = max(n_rows, n_cols) * p_size_km + + # Get central wavelengths + wavelengths_km = get_cascade_wavelengths( + n_levels, domain_size_km, p_size_km) + return wavelengths_km + +def get_base_time(valid_time, product, name, db): + # Get the base_time for the nwp run nearest to the valid_time in UTC zone + # Assume spin-up of 3 hours + start_base_time = valid_time - datetime.timedelta(hours=27) + end_base_time = valid_time - datetime.timedelta(hours=3) + base_time_query = { + "metadata.product": product, + "metadata.base_time": {"$gte": start_base_time, "$lte": end_base_time} + } + 
col_name = f"{name}.rain.files" + nwp_base_times = db[col_name].distinct( + "metadata.base_time", base_time_query) + + if nwp_base_times is None: + logging.warning( + f"Failed to find {product} data for {valid_time}") + return None + + nwp_base_times.sort(reverse=True) + base_time = nwp_base_times[0] + + if base_time.tzinfo is None: + base_time = base_time.replace(tzinfo=datetime.timezone.utc) + + return base_time diff --git a/pysteps/mongo/mongodb_port_forwarding.md b/pysteps/mongo/mongodb_port_forwarding.md new file mode 100644 index 000000000..75cb2ea3d --- /dev/null +++ b/pysteps/mongo/mongodb_port_forwarding.md @@ -0,0 +1,128 @@ +# Port Forwarding to Remote MongoDB on Localhost + +This guide explains how to open port **27018** on your local machine and forward it to a **remote MongoDB** instance running on port **27017**, either temporarily or permanently using `systemd` and `autossh`. + +--- + +## 🔁 SSH Port Forwarding (Temporary) + +To forward local port `27018` to the remote MongoDB server on port `27017`, run: + +```bash +ssh -L 27018:localhost:27017 your_user@remote_host +``` + +Once the tunnel is active, connect to MongoDB using: + +```bash +mongosh --port 27018 +``` + +Or with a URI: + +```bash +mongodb://localhost:27018 +``` + +--- + +## 🔄 Making Port Forwarding Persistent with systemd and autossh + +To create a self-healing SSH tunnel that auto-reconnects on failure, use `autossh` with a `systemd` user service. + +### 1. Install autossh + +On Fedora: + +```bash +sudo dnf install autossh +``` + +On Debian/Ubuntu: + +```bash +sudo apt install autossh +``` + +--- + +### 2. Set up SSH keys + +```bash +ssh-keygen +ssh-copy-id radar@remote_host +``` + +Make sure `ssh radar@remote_host` works without a password. + +--- + +### 3. Create the systemd user service + +Create the file: +`~/.config/systemd/user/mongodb-tunnel.service` + +```ini +[Unit] +Description=Persistent SSH tunnel to radar MongoDB +After=network.target + +[Service] +Environment=AUTOSSH_GATETIME=0 +ExecStart=/usr/bin/autossh -M 0 -N -L 27018:10.8.0.41:27017 radar +Restart=always +RestartSec=10 + +[Install] +WantedBy=default.target +``` + +> - Replace `10.8.0.41` with the IP of the remote MongoDB server (not necessarily `localhost` on the remote if it’s bound to a specific interface). +> - `radar` is your SSH alias or username. Make sure it’s configured in `~/.ssh/config` if using an alias. + +--- + +### 4. Enable and start the tunnel + +```bash +systemctl --user daemon-reexec +systemctl --user daemon-reload +systemctl --user enable mongodb-tunnel +systemctl --user start mongodb-tunnel +``` + +To check the status: + +```bash +systemctl --user status mongodb-tunnel +``` + +--- + +### 5. Optional: Ensure ssh-agent is running + +Add to your shell startup script: + +```bash +eval "$(ssh-agent -s)" +ssh-add ~/.ssh/id_rsa +``` + +Or use your desktop’s SSH key manager. + +--- + +## ✅ Verifying the Tunnel + +Once active, test it with: + +```bash +mongosh --port 27018 +``` + +--- + +## 🔐 Security Note + +- Keep your SSH key safe with a passphrase. +- Use `ufw`, `firewalld`, or similar to restrict access if needed. diff --git a/pysteps/mongo/nc_utils.py b/pysteps/mongo/nc_utils.py new file mode 100644 index 000000000..a8709ccb5 --- /dev/null +++ b/pysteps/mongo/nc_utils.py @@ -0,0 +1,318 @@ +""" + Refactored IO utilities for pysteps. 
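    Helpers here write rain grids as CF-1.7 netCDF (either fully in memory or
    via a temporary file returned as a BytesIO buffer), read them back from a
    bytes buffer, and build output file names from templates such as
    "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc", where $N is the
    domain name, $P the product, $V the valid time, $B the NWP base time, and
    $E the ensemble member.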
+""" +import numpy as np +from pyproj import CRS +import netCDF4 +from datetime import datetime, timezone +from typing import Optional +import io + +def replace_extension(filename: str, new_ext: str) -> str: + return f"{filename.rsplit('.', 1)[0]}{new_ext}" + +def convert_timestamps_to_datetimes(timestamps): + """Convert POSIX timestamps to datetime objects.""" + return [datetime.fromtimestamp(ts, tz=timezone.utc) for ts in timestamps] + + +def write_netcdf(rain: np.ndarray, geo_data: dict, time: int): + """ + Write rain data as a NetCDF4 memory buffer. + + :param buffer: A BytesIO buffer to store the NetCDF data. + :param rain: Rainfall data as a NumPy array. + :param geo_data: Dictionary containing geo-referencing data with keys: + 'x', 'y', 'projection', and other metadata. + :param time: POSIX timestamp representing the time dimension. + :return: The BytesIO buffer containing the NetCDF data. + """ + x = geo_data['x'] + y = geo_data['y'] + # Default to WGS84 if not provided + projection = geo_data.get('projection', 'EPSG:4326') + + # Create an in-memory NetCDF dataset + ds = netCDF4.Dataset('inmemory.nc', mode='w', memory=1024) + + # Define dimensions + y_dim = ds.createDimension("y", len(y)) + x_dim = ds.createDimension("x", len(x)) + t_dim = ds.createDimension("time", 1) + + # Define coordinate variables + y_var = ds.createVariable("y", "f4", ("y",)) + x_var = ds.createVariable("x", "f4", ("x",)) + t_var = ds.createVariable("time", "i8", ("time",)) + + # Define rain variable + rain_var = ds.createVariable( + "rainfall", "i2", ("time", "y", "x"), zlib=True + ) # int16 with a fill value + rain_var.scale_factor = 0.1 + rain_var.add_offset = 0.0 + rain_var.units = "mm/h" + rain_var.long_name = "Rainfall rate" + rain_var.grid_mapping = "projection" + + # Assign coordinate values + y_var[:] = y + y_var.standard_name = "projection_y_coordinate" + y_var.units = "m" + + x_var[:] = x + x_var.standard_name = "projection_x_coordinate" + x_var.units = "m" + + t_var[:] = [time] + t_var.standard_name = "time" + t_var.units = "seconds since 1970-01-01T00:00:00Z" + + # Handle NaNs in rain data and assign to variable + rain[np.isnan(rain)] = -1 + rain_var[0,:, :] = rain + + # Define spatial reference (CRS) + crs = CRS.from_user_input(projection) + cf_grid_mapping = crs.to_cf() + + # Create spatial reference variable + spatial_ref = ds.createVariable("projection", "i4") + for key, value in cf_grid_mapping.items(): + setattr(spatial_ref, key, value) + + # Add global attributes + ds.Conventions = "CF-1.7" + ds.title = "Rainfall data" + ds.institution = "Weather Radar New Zealand Ltd" + ds.references = "" + ds.comment = "" + return ds.close() + +import io +import tempfile +import netCDF4 +import os +import numpy as np +from pyproj import CRS + +def write_netcdf_io(rain: np.ndarray, geo_data: dict, time: int) -> io.BytesIO: + """ + Write a NetCDF file to a temporary file, read it into memory, and return a BytesIO buffer. 
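    A sketch of the intended round trip (variable names are illustrative):

        buf = write_netcdf_io(rain, geo_data, int(valid_time.timestamp()))
        fs.put(buf.getvalue(), filename=filename, metadata=metadata)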
+ """ + + x = geo_data['x'] + y = geo_data['y'] + projection = geo_data.get('projection', 'EPSG:4326') + + # Use NamedTemporaryFile to create a temp NetCDF file + with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: + tmp_path = tmp.name + + # Create NetCDF file on disk + ds = netCDF4.Dataset(tmp_path, mode='w', format='NETCDF4') + + # Define dimensions + ds.createDimension("y", len(y)) + ds.createDimension("x", len(x)) + ds.createDimension("time", 1) + + # Coordinate variables + y_var = ds.createVariable("y", "f4", ("y",)) + x_var = ds.createVariable("x", "f4", ("x",)) + t_var = ds.createVariable("time", "i8", ("time",)) + + # Rainfall variable + rain_var = ds.createVariable( + "rainfall", "i2", ("time", "y", "x"), + zlib=True, complevel=5, fill_value=-1 + ) + rain_var.scale_factor = 0.1 + rain_var.add_offset = 0.0 + rain_var.units = "mm/h" + rain_var.long_name = "Rainfall rate" + rain_var.grid_mapping = "projection" + + # Assign values + y_var[:] = y + x_var[:] = x + t_var[:] = [time] + rain_var[0, :, :] = np.nan_to_num(rain, nan=-1) + + y_var.standard_name = "projection_y_coordinate" + y_var.units = "m" + x_var.standard_name = "projection_x_coordinate" + x_var.units = "m" + t_var.standard_name = "time" + t_var.units = "seconds since 1970-01-01T00:00:00Z" + + # CRS + crs = CRS.from_user_input(projection) + cf_grid_mapping = crs.to_cf() + spatial_ref = ds.createVariable("projection", "i4") + for key, value in cf_grid_mapping.items(): + setattr(spatial_ref, key, value) + + # Global attributes + ds.Conventions = "CF-1.7" + ds.title = "Rainfall data" + ds.institution = "Weather Radar New Zealand Ltd" + ds.references = "" + ds.comment = "" + + ds.close() + + # Now read into memory + with open(tmp_path, "rb") as f: + nc_bytes = f.read() + + os.remove(tmp_path) + return io.BytesIO(nc_bytes) + + +def generate_geo_data(x, y, projection='EPSG:2193'): + """Generate geo-referencing data.""" + return { + "projection": projection, + "x": x, + "y": y, + "x1": np.round(x[0],decimals=0), + "x2": np.round(x[-1],decimals=0), + "y1": np.round(y[0],decimals=0), + "y2": np.round(y[-1],decimals=0), + "xpixelsize": np.round(x[1] - x[0],decimals=0), + "ypixelsize": np.round(y[1] - y[0],decimals=0), + "cartesian_unit": 'm', + "yorigin": 'lower', + "unit": 'mm/h', + "transform": None, + "threshold": 0.1, + "zerovalue": 0 + } + + +def read_nc(buffer: bytes): + """ + Read netCDF file from a memory buffer and return geo-referencing data and rain rates. + + :param buffer: Byte data of the NetCDF file from GridFS. + :return: Tuple containing geo-referencing data, valid times, and rain rate array. 
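    Example (grid_out is a GridFS file handle, shown for illustration):

        geo_data, valid_times, rain_rate = read_nc(grid_out.read())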
+ """ + # Convert the byte buffer to a BytesIO object + byte_stream = io.BytesIO(buffer) + + # Open the NetCDF dataset + with netCDF4.Dataset('inmemory', mode='r', memory=byte_stream.getvalue()) as ds: + # Extract geo-referencing data + x = ds.variables["x"][:] + y = ds.variables["y"][:] + geo_data = generate_geo_data(x, y) + + # Convert timestamps to datetime + valid_times = convert_timestamps_to_datetimes(ds.variables["time"][:]) + + # Extract rain rates + rain_rate = ds.variables["rainfall"][:] + + # Replace invalid data with NaN and squeeze dimensions + rain_rate = np.squeeze(rain_rate) + rain_rate[rain_rate < 0] = np.nan + valid_times = np.squeeze(valid_times) + + return geo_data, valid_times, rain_rate + + +def validate_keys(keys, mandatory_keys): + """Validate the presence of mandatory keys.""" + missing_keys = [key for key in mandatory_keys if key not in keys] + if missing_keys: + raise KeyError(f"Missing mandatory keys: {', '.join(missing_keys)}") + +def make_nc_name_dt(out_file_name, name, out_product, valid_time, base_time, iens): + + vtime = valid_time + if vtime.tzinfo is None: + vtime = vtime.replace(tzinfo=timezone.utc) + vtime_stamp = vtime.timestamp() + + if base_time is not None: + btime = base_time + if btime.tzinfo is None: + btime = btime.replace(tzinfo=timezone.utc) + btime_stamp = btime.timestamp() + else: + btime_stamp = None + + fx_file_name = make_nc_name( + out_file_name, name, out_product, vtime_stamp, btime_stamp, iens) + return fx_file_name + + +def make_nc_name(name_template: str, name: str, prod: str, valid_time: int, + base_time: Optional[int] = None, ens: Optional[int] = None) -> str: + """ + Generate a file name using a template. + + :param name_template: Template for the file name + :param name: Name of the domain - Mandatory + :param prod: Name of the product - Mandatory + :param valid_time: Valid time of the field - Mandatory + :param run_time: NWP run time - Optional + :param ens: Ensemble member - Optional + :return: String with the file name + """ + result = name_template + + # Set up the valid time + vtime_info = datetime.fromtimestamp(valid_time, tz=timezone.utc) + + # Set up the NWP base time if available + btime_info = datetime.fromtimestamp( + base_time, tz=timezone.utc) if base_time is not None else None + + has_flag = True + while has_flag: + # Search for a flag + flag_posn = result.find("$") + if flag_posn == -1: + has_flag = False + else: + # Get the field type + f_type = result[flag_posn + 1] + + try: + # Add the valid and base times + if f_type in ['V', 'B']: + # Get the required format string + field_start = result.find("{", flag_posn + 1) + field_end = result.find("}", flag_posn + 1) + if field_start == -1 or field_end == -1: + raise ValueError(f"Invalid time format for flag '${ + f_type}' in template.") + + time_format = result[field_start + 1:field_end] + if f_type == 'V': + date_str = vtime_info.strftime(time_format) + elif f_type == 'B' and btime_info: + date_str = btime_info.strftime(time_format) + else: + date_str = "" + + # Replace the format field with the formatted time + result = result[:flag_posn] + \ + date_str + result[field_end + 1:] + elif f_type == 'P': + result = result[:flag_posn] + prod + result[flag_posn + 2:] + elif f_type == 'N': + result = result[:flag_posn] + name + result[flag_posn + 2:] + elif f_type == 'E' and ens is not None: + result = result[:flag_posn] + \ + f"{ens:02d}" + result[flag_posn + 2:] + else: + raise ValueError(f"Unknown or unsupported flag '${ + f_type}' in template.") + except Exception as e: 
+ raise ValueError(f"Error processing flag '${ + f_type}': {str(e)}") + + return result \ No newline at end of file diff --git a/pysteps/mongo/pysteps_config.json b/pysteps/mongo/pysteps_config.json new file mode 100644 index 000000000..422b398dc --- /dev/null +++ b/pysteps/mongo/pysteps_config.json @@ -0,0 +1,133 @@ +{ + "name": "AKL", + "pysteps": { + "n_cascade_levels": 5, + "timestep": 600, + "kmperpixel": 2.0, + "precip_threshold": 1.0, + "transform":"dB", + "threshold":-10, + "zerovalue":-11, + "scale_break": 20, + "extrapolation_method": "semilagrangian", + "decomposition_method": "fft", + "bandpass_filter_method": "gaussian", + "noise_method": "nonparametric", + "noise_stddev_adj": null, + "ar_order": 2, + "velocity_perturbation_method": null, + "conditional": false, + "probmatching_method": "cdf", + "mask_method": "incremental", + "seed": null, + "num_workers": 1, + "fft_method": "numpy", + "domain": "spatial", + "extrapolation_kwargs": {}, + "filter_kwargs": {}, + "noise_kwargs": {}, + "velocity_perturbation_kwargs": {}, + "mask_kwargs": {}, + "measure_time": false, + "callback": null, + "return_output": true + }, + "output": { + "products":["qpesim", "auckprec","nowcast", "nwpblend"], + "qpesim":{ + "n_ens_members": 50, + "n_forecasts": 144, + "fx_update":10800, + "gridfs_out": true, + "nc_out": false, + "nwp_product": "auckprec", + "rad_product": "QPE", + "out_product": "qpesim", + "out_dir_name": null, + "out_file_name": "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc" + }, + "auckprec":{ + "gridfs_out": true, + "nc_out": false, + "out_product": "auckprec", + "tmp_dir": "$HOME/tmp", + "out_dir_name": null, + "out_file_name": "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc" + }, + "nowcast":{ + "n_ens_members": 25, + "n_forecasts": 12, + "fx_update":1800, + "gridfs_out": true, + "nc_out": false, + "nwp_product": "auckprec", + "rad_product": "QPE", + "out_product": "nowcast", + "out_dir_name": null, + "out_file_name": "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc" + }, + "nwpblend":{ + "n_ens_members": 25, + "n_forecasts": 72, + "blend_width":180, + "gridfs_out": true, + "nc_out": false, + "nwp_product": "auckprec", + "rad_product": "QPE", + "out_product": "nwpblend", + "out_dir_name": null, + "out_file_name": "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc" + } + + }, + "domain": { + "n_rows": 128, + "n_cols": 128, + "p_size": 2000, + "start_x": 1627000, + "start_y": 5854000 + }, + "projection": { + "epsg": "EPSG:2193", + "name": "transverse_mercator", + "central_meridian": 173.0, + "latitude_of_origin": 0.0, + "scale_factor": 0.9996, + "false_easting": 1600000.0, + "false_northing": 10000000.0 + }, + "dynamic_scaling": { + "central_wave_lengths": [ + 128.0, + 33.793468576034414, + 14.709461552298315, + 6.402665020067988, + 2.0 + ], + "space_time_exponent": 0.8471382397261171, + "lag2_constants": [ + 1.0019662531226008, + 0.9895839949795303, + 0.9544104783679567, + 0.8610790248307003, + 0.6447730842290677 + ], + "lag2_exponents": [ + 2.576213907004238, + 2.7557407261557945, + 2.6102093715829717, + 2.222301153284, + 1.6742864097867338 + ], + "cor_len_percentiles": [ + 95, + 50, + 5 + ], + "cor_len_pvals": [ + 894.5855025523205, + 253.723822128425, + 84.28166129094791 + ] + } +} diff --git a/pysteps/mongo/write_ensemble.py b/pysteps/mongo/write_ensemble.py new file mode 100644 index 000000000..7c06713e5 --- /dev/null +++ b/pysteps/mongo/write_ensemble.py @@ -0,0 +1,335 @@ +""" +Output an nc file with past and forcast ensemble +""" + +from 
models.mongo_access import get_db, get_config +from models.nc_utils import convert_timestamps_to_datetimes, make_nc_name +from pymongo import MongoClient +import logging +import argparse +import pymongo +from gridfs import GridFSBucket, NoFile +import numpy as np +import datetime +import netCDF4 +import pandas as pd +from pathlib import Path +import xarray as xr +from pyproj import CRS +import io + +import numpy as np +import netCDF4 +from pathlib import Path + + +def write_rainfall_netcdf(filename: Path, rainfall: np.ndarray, + x: np.ndarray, y: np.ndarray, + time: list, ensemble: np.ndarray): + """ + Write rainfall data to NetCDF using low-level netCDF4 interface. + - rainfall: 4D np.ndarray (ensemble, time, y, x), float32, mm/h with NaNs + - x, y: 1D arrays of projection coordinates in meters + - time: list of timezone-aware datetime.datetime objects + - ensemble: 1D array of ensemble member IDs (int) + """ + + n_ens, n_times, ny, nx = rainfall.shape + assert len(time) == n_times + assert len(ensemble) == n_ens + + with netCDF4.Dataset(filename, "w", format="NETCDF4") as ds: + # Create dimensions + ds.createDimension("ensemble", n_ens) + ds.createDimension("time", n_times) + ds.createDimension("y", ny) + ds.createDimension("x", nx) + + # Coordinate variables + x_var = ds.createVariable("x", "f4", ("x",)) + y_var = ds.createVariable("y", "f4", ("y",)) + t_var = ds.createVariable("time", "i4", ("time",)) + ens_var = ds.createVariable("ensemble", "i4", ("ensemble",)) + + x_var[:] = x + y_var[:] = y + ens_var[:] = ensemble + t_var[:] = netCDF4.date2num( + time, units="seconds since 1970-01-01T00:00:00", calendar="standard") + + x_var.units = "m" + x_var.standard_name = "projection_x_coordinate" + y_var.units = "m" + y_var.standard_name = "projection_y_coordinate" + t_var.units = "seconds since 1970-01-01 00:00:00" + t_var.standard_name = "time" + ens_var.long_name = "ensemble member" + + # CRS variable (dummy scalar) + crs_var = ds.createVariable("crs", "i4") + crs_var.grid_mapping_name = "transverse_mercator" + crs_var.scale_factor_at_central_meridian = 0.9996 + crs_var.longitude_of_central_meridian = 173.0 + crs_var.latitude_of_projection_origin = 0.0 + crs_var.false_easting = 1600000.0 + crs_var.false_northing = 10000000.0 + crs_var.semi_major_axis = 6378137.0 + crs_var.inverse_flattening = 298.257222101 + crs_var.spatial_ref = "EPSG:2193" + + # Rainfall variable (compressed int16 with scale) + rain_var = ds.createVariable( + "rainfall", "i2", ("ensemble", "time", "y", "x"), + zlib=True, complevel=5, fill_value=-1 + ) + rain_var.scale_factor = 0.1 + rain_var.add_offset = 0.0 + rain_var.units = "mm/h" + rain_var.long_name = "Rainfall rate" + rain_var.grid_mapping = "crs" + + rainfall[np.isnan(rainfall)] = -1 + rain_var[:, :, :, :] = rainfall + + +def is_valid_iso8601(time_str: str) -> bool: + """Check if the given string is a valid ISO 8601 datetime.""" + try: + datetime.datetime.fromisoformat(time_str) + return True + except ValueError: + return False + + +def get_filenames(db: MongoClient, name: str, query: dict): + meta_coll = db[f"{name}.rain.files"] + + # Fetch matching filenames and metadata in a single query + fields_projection = {"_id": 1, "filename": 1, "metadata": 1} + results = meta_coll.find(query, projection=fields_projection).sort( + "filename", pymongo.ASCENDING) + files = [] + for doc in results: + record = { + "valid_time": doc["metadata"]["valid_time"], + "base_time": doc["metadata"]["base_time"], + "ensemble": doc["metadata"]["ensemble"], + "_id": doc["_id"], + 
"filename": doc["filename"] + } + files.append(record) + + files_df = pd.DataFrame(files) + return files_df + + +def main(): + parser = argparse.ArgumentParser( + description="Write rainfall fields to a netCDF file") + parser.add_argument('-n', '--name', required=True, + help='Domain name (e.g., AKL)') + parser.add_argument('-b', '--base_time', type=str, required=True, + help='Base time yyyy-mm-ddTHH:MM:SS') + parser.add_argument('-d', '--directory', required=True, type=Path, + help='Path to output directory for the figures') + + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + + # Validate start and end time and read them in + if args.base_time and is_valid_iso8601(args.base_time): + base_time = datetime.datetime.fromisoformat(str(args.base_time)) + if base_time.tzinfo is None: + base_time = base_time.replace(tzinfo=datetime.timezone.utc) + else: + logging.error( + "Invalid base time format. Please provide a valid ISO 8601 time string.") + return + + file_dir = args.directory + if not file_dir.exists(): + logging.error(f"Invalid output diectory {file_dir}") + return + + name = args.name + db = get_db() + + # Get the domain geometry + config = get_config(db, name) + nwpblend_config = config["output"]["nwpblend"] + n_ens = nwpblend_config.get("n_ens_members") + n_fx = nwpblend_config.get("n_forecasts") + n_qpe = n_fx + ts_seconds = config["pysteps"]["timestep"] + ts = datetime.timedelta(seconds=ts_seconds) + + # Get the file names for the input data + start_qpe = base_time - n_qpe * ts + end_qpe = base_time + query = { + "metadata.product": "QPE", + "metadata.valid_time": {"$gte": start_qpe, "$lte": end_qpe} + } + qpe_df = get_filenames(db, name, query) + + start_blend = base_time + end_blend = base_time + n_fx*ts + + query = { + "metadata.product": "nwpblend", + "metadata.valid_time": {"$gt": base_time, "$lte": end_blend}, + "metadata.base_time": base_time + } + blend_df = get_filenames(db, name, query) + + qpe_fields = [] + qpe_times = [] + + bucket_name = f"{name}.rain" + bucket = GridFSBucket(db, bucket_name=bucket_name) + + for index, row in qpe_df.iterrows(): + filename = row["filename"] + with bucket.open_download_stream_by_name(filename) as stream: + buffer = stream.read() + byte_stream = io.BytesIO(buffer) + ds = netCDF4.Dataset('inmemory', mode='r', + memory=byte_stream.getvalue()) + + # Extract rain rate and handle 3D (time, y, x) or 2D (y, x) + rain_rate = ds.variables["rainfall"][:] + if rain_rate.ndim == 3: + rain_rate = rain_rate[0, :, :] # Take first time slice if present + + # Get valid time (assuming one timestamp per file) + time_var = ds.variables["time"][:] + valid_time = convert_timestamps_to_datetimes( + time_var)[0] # e.g., returns a list + + if index == 0: + y_ref = ds.variables["y"][:] + x_ref = ds.variables["x"][:] + else: + assert np.allclose(ds.variables["y"][:], y_ref) + assert np.allclose(ds.variables["x"][:], x_ref) + + # Accumulate + qpe_fields.append(rain_rate) + qpe_times.append(valid_time) + + # Convert to xarray.DataArray + qpe_array = xr.DataArray( + data=np.stack(qpe_fields), # shape: (time, y, x) + coords={"time": qpe_times, "y": y_ref, "x": x_ref}, + dims=["time", "y", "x"], + name="qpe" + ) + + # Ensure sorted and aligned valid_times across all ensemble members + ensembles = np.sort(blend_df["ensemble"].unique()) + + blend_times = np.sort(blend_df["valid_time"].unique()) + # ensures tz-aware datetime64[ns, UTC] + blend_times = pd.to_datetime(blend_times, utc=True) + # convert to native datetime.datetime + blend_times = 
[dt.to_pydatetime() for dt in blend_times] + + n_ens = len(ensembles) + n_time = len(blend_times) + ny, nx = y_ref.shape[0], x_ref.shape[0] + + # Initialize a 4D array (ensemble, time, y, x) + blend_data = np.full((n_ens, n_time, ny, nx), np.nan, dtype=np.float32) + + # Mapping from value to index + ensemble_to_idx = {ens: i for i, ens in enumerate(ensembles)} + time_to_idx = {vt: i for i, vt in enumerate(blend_times)} + + for index, row in blend_df.iterrows(): + filename = row["filename"] + ensemble = row["ensemble"] + with bucket.open_download_stream_by_name(filename) as stream: + buffer = stream.read() + byte_stream = io.BytesIO(buffer) + ds = netCDF4.Dataset('inmemory', mode='r', + memory=byte_stream.getvalue()) + + rain_rate = ds.variables["rainfall"][:] + if rain_rate.ndim == 3: + rain_rate = rain_rate[0, :, :] + + time_var = ds.variables["time"][:] + valid_time = convert_timestamps_to_datetimes(time_var)[0] + + assert np.allclose(ds.variables["y"][:], y_ref) + assert np.allclose(ds.variables["x"][:], x_ref) + + # Write into 4D array + ei = ensemble_to_idx[ensemble] + ti = time_to_idx[valid_time] + blend_data[ei, ti, :, :] = rain_rate + + # Build DataArray + blend_array = xr.DataArray( + data=blend_data, + coords={ + "ensemble": ensembles, + "time": blend_times, + "y": y_ref, + "x": x_ref + }, + dims=["ensemble", "time", "y", "x"], + name="blend" + ) + + qpe_times = list(qpe_array.coords["time"].values) + blend_times = list(blend_array.coords["time"].values) + combined_times = qpe_times + blend_times + + # Convert to tz-aware datetime.datetime + combined_times = pd.to_datetime(combined_times, utc=True) + combined_times = [t.to_pydatetime() for t in combined_times] + qpe_data = qpe_array.values + + # Tile across ensemble: + n_ens = blend_array.sizes["ensemble"] + qpe_broadcast = np.tile(qpe_data[None, :, :, :], (n_ens, 1, 1, 1)) + + # Stack QPE and forecasts: + combined_data = np.concatenate([qpe_broadcast, blend_array.values], axis=1) + + # Create combined xarray + combined_array = xr.DataArray( + data=combined_data, + coords={ + "ensemble": blend_array.coords["ensemble"], + "time": combined_times, + "y": y_ref, + "x": x_ref + }, + dims=["ensemble", "time", "y", "x"], + name="rainfall" + ) + template = "$N_$P_$V{%Y%m%d_%H%M%S}.nc" + tstamp = base_time.timestamp() + product = "qpe_nwpblend" + fname = make_nc_name(template, name, product, tstamp, None, None) + fdir = args.directory + file_name = fdir / fname + logging.info(f"Writing data to {file_name}") + + write_rainfall_netcdf( + filename=file_name, + rainfall=combined_array.values, + x=x_ref, + y=y_ref, + time=combined_times, + ensemble=combined_array.coords["ensemble"].values + ) + + return + + +if __name__ == "__main__": + main() diff --git a/pysteps/mongo/write_nc_files.py b/pysteps/mongo/write_nc_files.py new file mode 100644 index 000000000..53192a53a --- /dev/null +++ b/pysteps/mongo/write_nc_files.py @@ -0,0 +1,307 @@ +""" +Write rainfall grids to a netCDF file +""" + +from models import read_nc, make_nc_name_dt +from models import get_db, get_config +from pymongo import MongoClient +import logging +import argparse +import gridfs +import pymongo +import numpy as np +import datetime +import os +import netCDF4 +from pyproj import CRS +import pandas as pd + + +def is_valid_iso8601(time_str: str) -> bool: + """Check if the given string is a valid ISO 8601 datetime.""" + try: + datetime.datetime.fromisoformat(time_str) + return True + except ValueError: + return False + + +def get_base_times(db, base_time_query): + meta_coll = 
db["AKL.rain.files"] + base_times = list(meta_coll.distinct( + "metadata.base_time", base_time_query)) + return base_times + + +def get_valid_times(db, valid_time_query): + meta_coll = db["AKL.rain.files"] + valid_times = list(meta_coll.distinct( + "metadata.valid_time", valid_time_query)) + return valid_times + + +def get_rain_fields(db: pymongo.MongoClient, query: dict): + meta_coll = db["AKL.rain.files"] + + # Fetch matching filenames and metadata in a single query + fields_projection = {"_id": 1, "filename": 1, "metadata": 1} + results = meta_coll.find(query, projection=fields_projection).sort( + "filename", pymongo.ASCENDING) + files = [] + for doc in results: + record = {"_id": doc["_id"], + "valid_time": doc["metadata"]["valid_time"]} + files.append(record) + return files + + +def load_rain_field(db, file_id): + """Retrieve a specific rain field NetCDF file from GridFS and return as numpy array""" + fs = gridfs.GridFS(db, collection='AKL.rain') + file_obj = fs.get(file_id) + metadata = file_obj.metadata + data_bytes = file_obj.read() + geo_data, valid_time, rain_rate = read_nc(data_bytes) + if isinstance(valid_time, np.ndarray): + valid_time = valid_time.tolist() + return geo_data, metadata, valid_time, rain_rate + + +def write_netcdf(file_path: str, rain: np.ndarray, geo_data: dict, times: list[datetime.datetime], ensembles: list[int]) -> None: + """ + Write a set of rainfall grids to a CF netCDF file + Args: + file_path (str): Full path to the output file + rain (np.ndarray): Rainfall array. Shape is [ensemble, time, y, x] if ensembles is provided, + otherwise [time, y, x] + geo_data (dict): Geospatial information + times (list[datetime.datetime]): list of valid times + ensembles (list[int]): Optional list of valid ensemble numbers + """ + # Convert the times to seconds since 1970-01-01T00:00:00Z + time_stamps = [] + for time in times: + if time.tzinfo is None: + time = time.replace(tzinfo=datetime.timezone.utc) + time_stamp = time.timestamp() + time_stamps.append(time_stamp) + + x = geo_data['x'] + y = geo_data['y'] + projection = geo_data.get('projection', 'EPSG:4326') + + # Create NetCDF file on disk + with netCDF4.Dataset(file_path, mode='w', format='NETCDF4') as ds: + + # Define dimensions + ds.createDimension("y", len(y)) + ds.createDimension("x", len(x)) + ds.createDimension("time", len(times)) + + # Coordinate variables + y_var = ds.createVariable("y", "f4", ("y",)) + y_var[:] = y + y_var.standard_name = "projection_y_coordinate" + y_var.units = "m" + + x_var = ds.createVariable("x", "f4", ("x",)) + x_var[:] = x + x_var.standard_name = "projection_x_coordinate" + x_var.units = "m" + + t_var = ds.createVariable("time", "f8", ("time",)) + t_var[:] = time_stamps + t_var.standard_name = "time" + t_var.units = "seconds since 1970-01-01T00:00:00Z" + t_var.calendar = "standard" + + # Set up the ensemble if we have one + if ensembles is not None: + ds.createDimension("ensemble", len(ensembles)) + e_var = ds.createVariable("ensemble", "i4", ("ensemble",)) + e_var[:] = ensembles + e_var.standard_name = "ensemble" + e_var.units = "1" + + # Rainfall + if ensembles is None: + rain_var = ds.createVariable( + "rainfall", "i2", ("time", "y", "x"), + zlib=True, complevel=5, fill_value=-1 + ) + rain_var[:, :, :] = np.nan_to_num(rain, nan=-1) + + else: + rain_var = ds.createVariable( + "rainfall", "i2", ("ensemble", "time", "y", "x"), + zlib=True, complevel=5, fill_value=-1 + ) + rain_var[:, :, :, :] = np.nan_to_num(rain, nan=-1) + + rain_var.scale_factor = 0.1 + rain_var.add_offset = 0.0 + 
rain_var.units = "mm/h" + rain_var.long_name = "Rainfall rate" + rain_var.grid_mapping = "projection" + rain_var.coordinates = "time y x" if ensembles is None else "ensemble time y x" + + # CRS + crs = CRS.from_user_input(projection) + cf_grid_mapping = crs.to_cf() + spatial_ref = ds.createVariable("projection", "i4") + for key, value in cf_grid_mapping.items(): + setattr(spatial_ref, key, value) + + # Global attributes + ds.Conventions = "CF-1.10" + ds.title = "Rainfall data" + ds.institution = "Weather Radar New Zealand Ltd" + ds.references = "" + ds.comment = "" + return + + +def main(): + parser = argparse.ArgumentParser( + description="Write rainfall fields to a netCDF file") + + parser.add_argument('-s', '--start', type=str, required=True, + help='Start time yyyy-mm-ddTHH:MM:SS') + parser.add_argument('-e', '--end', type=str, required=True, + help='End time yyyy-mm-ddTHH:MM:SS') + parser.add_argument('-n', '--name', type=str, required=True, + help='Name of domain [AKL]') + parser.add_argument('-p', '--product', type=str, required=True, + help='Name of input product [QPE, auckprec]') + + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + + # Validate start and end time and read them in + if args.start and is_valid_iso8601(args.start): + start_time = datetime.datetime.fromisoformat(str(args.start)) + if start_time.tzinfo is None: + start_time = start_time.replace(tzinfo=datetime.timezone.utc) + else: + logging.error( + "Invalid start time format. Please provide a valid ISO 8601 time string.") + return + + if args.end and is_valid_iso8601(args.end): + end_time = datetime.datetime.fromisoformat(str(args.end)) + if end_time.tzinfo is None: + end_time = end_time.replace(tzinfo=datetime.timezone.utc) + else: + logging.error( + "Invalid start time format. Please provide a valid ISO 8601 time string.") + return + + name = str(args.name) + product = str(args.product) + valid_products = ["QPE", "auckprec", "qpesim"] + if product not in valid_products: + logging.error( + f"Invalid product. 
Please provide either {valid_products}.") + return + + db = get_db() + meta_coll = db["AKL.rain.files"] + + if product == "QPE": + file_id_query = {'metadata.product': product, + 'metadata.valid_time': {"$gte": start_time, "$lte": end_time}} + file_ids = get_rain_fields(db, file_id_query) + + out_grid = [] + valid_times = [] + geo_out = None + expected_shape = None + for file_id in file_ids: + geo_data, metadata, nc_times, rain_data = load_rain_field( + db, file_id["_id"]) + + if expected_shape is None: + expected_shape = rain_data.shape + elif rain_data.shape != expected_shape: + logging.error(f"Inconsistent rain_data shape: expected {expected_shape}, got {rain_data.shape}") + return + + out_grid.append(rain_data) + valid_times.append(nc_times) + if geo_out is None: + geo_out = geo_data + + # QPE files are named using the start and end valid times + name_template = "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}.nc" + file_name = make_nc_name_dt( + name_template, name, product, start_time, end_time, None) + out_array = np.array(out_grid) + + logging.info(f"Writing {file_name}") + write_netcdf(file_name, out_array, geo_out, valid_times, None) + + else: + + # Get the list of base times in the time period + base_time_query = {'metadata.product': product, + 'metadata.base_time': {"$gte": start_time, "$lte": end_time}} + base_times = list(meta_coll.distinct( + "metadata.base_time", base_time_query)) + + # Loop over the base times that have been found + for base_time in base_times: + + # Get the sorted list of ensmble members and valid times for this base time + ensemble_query = {'metadata.product': product, + 'metadata.base_time': base_time} + ensembles = list(meta_coll.distinct( + "metadata.ensemble", ensemble_query)) + ensembles.sort() + ne = len(ensembles) + valid_times = list(meta_coll.distinct("metadata.valid_time", ensemble_query)) + valid_times.sort() + nt = len(valid_times) + # Loop over the ensembles and read in the grids + out_grid = [] + geo_out = None + expected_shape = None + for ensemble in ensembles: + + # Get all the valid times for this ensemble + file_id_query = {'metadata.product': product, + 'metadata.base_time': base_time, 'metadata.ensemble': ensemble} + # Check that the expected number of fields have been found + file_ids = get_rain_fields(db, file_id_query) + if len(valid_times) != len(file_ids): + logging.error(f"{base_time}:Expected {len(valid_times)} found {len(file_ids)} valid times") + + for file_id in file_ids: + geo_data, metadata, nc_times, rain_data = load_rain_field( + db, file_id["_id"]) + + if expected_shape is None: + expected_shape = rain_data.shape + elif rain_data.shape != expected_shape: + logging.error(f"Inconsistent rain_data shape: expected {expected_shape}, got {rain_data.shape}") + return + + out_grid.append(rain_data) + if geo_out is None: + geo_out = geo_data + + # Forecast files are named using their base time + name_template = "$N_$P_$V{%Y-%m-%dT%H:%M:%S}.nc" + ny,nx = expected_shape + file_name = make_nc_name_dt( + name_template, name, product, base_time, None, None) + out_array = np.array(out_grid).reshape(ne,nt,ny,nx) + + logging.info(f"Writing {file_name}") + write_netcdf(file_name, out_array, geo_out, valid_times, ensembles) + + return + + +if __name__ == "__main__": + main() diff --git a/pysteps/param/README.md b/pysteps/param/README.md new file mode 100644 index 000000000..fb5a56c84 --- /dev/null +++ b/pysteps/param/README.md @@ -0,0 +1,45 @@ +# param + +## executable scripts + +### pysteps_param.py + +This is the main script to 
generate an ensemble nowcast using the parametric algorithms.
+
+### make_cascades.py
+
+Script to decompose and track the rainfall fields. The cascade states are written back into a GridFS bucket for later processing.
+
+### make_parameters.py
+
+Script to read the rainfall and cascade state data and calculate the STEPS parameters. The parameters are written back into a MongoDB collection.
+
+### nwp_param_qc.py
+
+The NWP rainfall fields are derived by interpolating hourly ensembles onto a 10-minute, 2 km grid and therefore contain significant errors. This script cleans and smooths the parameters derived from the NWP ensemble so that they are ready for use.
+
+### calibrate_ar_model.py
+
+This script reads the radar rain fields and calibrates the dynamic scaling model. The output is a set of figures for quality assurance and a JSON file with the model parameters that can be included in the main configuration JSON file.
+
+## Modules
+
+### broken_line.py
+
+Implementation of the broken line model. It is not used at this stage but could be used to generate time series of STEPS parameters in the future.
+
+### cascade_utils.py
+
+A simple function to calculate the central wavelength of each cascade level.
+
+### shared_utils.py
+
+Functions that are shared by the various forms of pySTEPS_param.
+
+### steps_params.py
+
+Data class to manage the parameters and functions to operate on them.
+
+### stochastic_generator.py
+
+Function to generate a single stochastic field given the parameters.
diff --git a/pysteps/param/broken_line.py b/pysteps/param/broken_line.py
new file mode 100644
index 000000000..101f3ddd2
--- /dev/null
+++ b/pysteps/param/broken_line.py
@@ -0,0 +1,124 @@
+
+import numpy as np
+from typing import Optional
+
+
+def broken_line(rain_mean: float, rain_std: float, time_step: int, duration: int,
+                h: Optional[float] = 0.60, q: Optional[float] = 0.85,
+                a_zero_min: Optional[float] = 1500, transform: Optional[bool] = True):
+    """
+    Generate a time series of rainfall using the broken line model.
+    Based on Seed et al. (2000), WRR.
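+    The series is built as the sum of N piecewise-linear components whose
+    breakpoint spacing shrinks as a_zero * q**p and whose variance shrinks as
+    var_zero * q**(p * h) for p = 0..N-1; when transform is True the sum is
+    formed in log space and exponentiated back to rainfall.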
+ + Args: + rain_mean (float): Mean of time series (must be > 0) + rain_std (float): Standard deviation of time series (must be > 0) + time_step (int): Time step in minutes (must be > 0) + duration (int): Duration of time series in minutes (must be > time_step) + h (float): Scaling exponent (0 < h < 1) + q (float): Scale change ratio between lines (0 < q < 1) + a_zero_min (float): Maximum time scale in minutes (must be > 0) + transform (bool): Use log transformation to generate the time series + + Returns: + np.ndarray: Rainfall time series of specified length, or None on error + """ + + # Validate input parameters + if not isinstance(rain_mean, (float, int)) or rain_mean <= 0: + print("Error: rain_mean must be a positive number.") + return None + if not isinstance(rain_std, (float, int)) or rain_std <= 0: + print("Error: rain_std must be a positive number.") + return None + if not isinstance(time_step, int) or time_step <= 0: + print("Error: time_step must be a positive integer.") + return None + if not isinstance(duration, int) or duration <= time_step: + print("Error: duration must be an integer greater than time_step.") + return None + if not isinstance(h, (float, int)) or not (0 < h < 1): + print("Error: h must be a float in the range (0,1).") + return None + if not isinstance(q, (float, int)) or not (0 < q < 1): + print("Error: q must be a float in the range (0,1).") + return None + if not isinstance(a_zero_min, (float, int)) or a_zero_min <= 0: + print("Error: a_zero_min must be a positive number.") + return None + + # Number of time steps to generate + length = duration // time_step # Ensure integer division + + # Calculate the lognormal mean and variance + if transform: + ratio = rain_std / rain_mean + bl_mean = np.log(rain_mean) - 0.5 * np.log(ratio**2 + 1) + bl_var = np.log(ratio**2 + 1) + else: + bl_mean = rain_mean + bl_var = rain_std ** 2.0 + + # Compute number of broken lines + a_zero = a_zero_min / time_step + N = max(1, int(np.log(1.0 / a_zero) / np.log(q)) + 1) # Prevents N=0 + + # Compute variance at the outermost scale + var_zero = bl_var * (1 - q**h) / (1 - q**(N * h)) + + # Initialize the time series with mean + model = np.full(length, bl_mean) + + # Add broken lines at different scales + for p in range(N): + break_step = a_zero * q**p + line_stdev = np.sqrt(var_zero * q**(p * h)) + line = make_line(line_stdev, break_step, length) + model += line + + # Transform back to rainfall space if needed + if transform: + rain = np.exp(model) + return rain + else: + return model + + +def make_line(std_dev, break_step, length): + """ + Generate a piecewise linear process with random breakpoints. + + Args: + std_dev (float): Standard deviation for generating y-values. + break_step (float): Distance between breakpoints. + length (int): Length of the output array. + + Returns: + np.ndarray: Interpolated line of given length. 
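+
+    Note: when break_step < 1 the function returns Gaussian noise with the
+    requested standard deviation; otherwise breakpoints spaced break_step
+    samples apart are generated, linearly interpolated onto the full series,
+    and rescaled to std_dev.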
+ """ + + # Generate random breakpoints + rng = np.random.default_rng(None) + + if break_step < 1: + y = rng.normal(0, std_dev, length) # Scaled correctly + return y + + # Number of breakpoints + n_points = 3 + int(length / break_step) + y = rng.normal(0, 1.5 * std_dev, n_points) # Scaled correctly + + # Generate x-coordinates with random offset + offset = rng.uniform(-break_step, 0) + x = [offset + break_step*ia for ia in range(n_points)] + + # Interpolate onto full time series + x_out = np.arange(length) + line = np.interp(x_out, x, y) + + # Normalize the standard deviation + line_std = np.std(line) + if line_std > 0: + line = (line - np.mean(line)) * (std_dev / line_std) + + return line diff --git a/pysteps/param/calibrate_ar_model.py b/pysteps/param/calibrate_ar_model.py new file mode 100644 index 000000000..6c0de1dca --- /dev/null +++ b/pysteps/param/calibrate_ar_model.py @@ -0,0 +1,321 @@ +""" +Estimate the parameters that manage the AR model using the observed QPE data +Write out the dynamic scaling configuration to a JSON file +""" + +from models import get_db, get_config +import datetime +import numpy as np +import pymongo +import argparse +import logging +import pandas as pd +from scipy.optimize import curve_fit +from pathlib import Path +import statsmodels.api as sm +from pysteps.cascade.bandpass_filters import filter_gaussian +import json +import matplotlib.pyplot as plt + + +def is_valid_iso8601(time_str: str) -> bool: + try: + datetime.datetime.fromisoformat(time_str) + return True + except ValueError: + return False + + +def get_auto_corls(db: pymongo.MongoClient, product: str, start_time: datetime.datetime, end_time: datetime.datetime): + params_col = db["AKL.params"] + query = { + 'metadata.valid_time': {'$gte': start_time, '$lte': end_time}, + 'metadata.product': product + } + projection = {"_id": 0, "metadata": 1, "cascade": 1} + data_cursor = params_col.find(query, projection=projection) + data_list = list(data_cursor) + + logging.info(f'Found {len(data_list)} documents') + + rows = [] + for doc in data_list: + row = {"valid_time": doc["metadata"]["valid_time"]} + + lag1 = doc.get("cascade", {}).get("lag1") + lag2 = doc.get("cascade", {}).get("lag2") + stds = doc.get("cascade", {}).get("stds") + + if lag1 is None or lag2 is None or stds is None: + continue + + for ia, val in enumerate(lag1): + row[f"lag1_{ia}"] = val + for ia, val in enumerate(lag2): + row[f"lag2_{ia}"] = val + for ia, val in enumerate(stds): + row[f"stds_{ia}"] = val + + rows.append(row) + + return pd.DataFrame(rows) + + +def power_law(x, a, b): + return a * np.power(x, b) + + +def fit_power_law(qpe_df, lev): + lag1_vals = qpe_df[f"lag1_{lev}"].values + lag2_vals = qpe_df[f"lag2_{lev}"].values + + q05 = np.quantile(lag1_vals, 0.05) + mask = lag1_vals > q05 + x = lag1_vals[mask] + y = lag2_vals[mask] + + coefs, _ = curve_fit(power_law, x, y) + + y_pred = power_law(x, coefs[0], coefs[1]) + ss_res = np.sum((y - y_pred) ** 2) + ss_tot = np.sum((y - np.mean(y)) ** 2) + r_squared = 1 - ss_res / ss_tot + + return (coefs[0], coefs[1], r_squared) + +def is_stationary(phi1, phi2): + return abs(phi2) < 1 and (phi1 + phi2) < 1 and (phi2 - phi1) < 1 + +def correlation_length(lag1, lag2, tol=1e-4, max_lag=1000): + if lag1 is None or lag2 is None: + return np.nan + + A = np.array([[1.0, lag1], [lag1, 1.0]]) + b = np.array([lag1, lag2]) + + try: + phi = np.linalg.solve(A, b) + except np.linalg.LinAlgError: + return np.nan + + phi1, phi2 = phi + if not is_stationary(phi1, phi2): + return np.nan + + rho_vals = [1.0, 
lag1, lag2] + for _ in range(3, max_lag): + next_rho = phi1 * rho_vals[-1] + phi2 * rho_vals[-2] + if abs(next_rho) < tol: + break + rho_vals.append(next_rho) + + return np.trapz(rho_vals, dx=10) + +import os +def generate_qaqc_plots(cor_len_df, ht, scales, lag2_constants, lag2_exponents, n_levels, output_prefix=""): + # Set up color map + cmap = plt.colormaps.get_cmap('tab10') + + # Create output directory if it doesn't exist + figs_dir = os.path.join("..", "figs") + os.makedirs(figs_dir, exist_ok=True) + + # Plot fitted vs observed lag1 & lag2 for three percentiles of cl_0 + percentiles = [95, 50, 5] + cl0_values = cor_len_df["cl_0"].values + pvals = np.percentile(cl0_values, percentiles) + L = scales[0] # reference scale in km + + for pval, pstr in zip(pvals, percentiles): + # Find closest row + idx = (np.abs(cl0_values - pval)).argmin() + row = cor_len_df.iloc[idx] + time_str = row["valid_time"].strftime("%Y-%m-%d %H:%M") + + # Set up the scaling correlation lengths for this case + T_ref = row["cl_0"] # T(t, L) at largest scale + T_levels = [T_ref * (l / L) ** ht for l in scales] + dt = 10 + + obs_lag1 = [] + obs_lag2 = [] + fit_lag1 = [] + fit_lag2 = [] + levels = [] + for ilevel in range(n_levels): + lag1 = row[f"lag1_{ilevel}"] + lag2 = row[f"lag2_{ilevel}"] + + a = lag2_constants[ilevel] + b = lag2_exponents[ilevel] + pl_lag1 = np.exp(-dt / T_levels[ilevel]) + pl_lag2 = a * (pl_lag1 ** b) + obs_lag1.append(lag1) + obs_lag2.append(lag2) + fit_lag1.append(pl_lag1) + fit_lag2.append(pl_lag2) + levels.append(ilevel) + + plt.figure(figsize=(6, 4)) + color_lag1 = cmap(1) + color_lag2 = cmap(2) + + plt.plot(scales, obs_lag1, 'x-', label='Observed lag1', color=color_lag1) + plt.plot(scales, fit_lag1, 'o-', label='Fit lag1',color= color_lag1) + plt.plot(scales, obs_lag2, 'x--', label='Observed lag2', color=color_lag2) + plt.plot(scales, fit_lag2, 'o--', label='Fit lag2', color=color_lag2) + + plt.xscale("log") + plt.xlabel("Scale (km)") + plt.ylabel("Autocorrelation") + plt.title(f"Fit vs Obs @ cl_0 ~ {pstr}th percentile \n{time_str}, corl len = {T_ref:.0f} min") + plt.grid(True, which="both", ls="--", alpha=0.6) + plt.legend() + plt.tight_layout() + + filename = f"{output_prefix}lags_{pstr}th_percentile.png" + plt.savefig(os.path.join(figs_dir, filename)) + plt.close() + + +def main(): + parser = argparse.ArgumentParser( + description="Calculate the parameters for the dynamic scaling model") + parser.add_argument('-s', '--start', type=str, + required=True, help='Start time yyyy-mm-ddTHH:MM:SS') + parser.add_argument('-e', '--end', type=str, required=True, + help='End time yyyy-mm-ddTHH:MM:SS') + parser.add_argument('-n', '--name', type=str, + required=True, help='Name of domain [AKL]') + parser.add_argument('-c', '--config', type=Path, required=True, + help='Path to output dynamic scaling configuration file') + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + + # Parse start and end time + try: + start_time = datetime.datetime.fromisoformat( + args.start).replace(tzinfo=datetime.timezone.utc) + end_time = datetime.datetime.fromisoformat( + args.end).replace(tzinfo=datetime.timezone.utc) + except ValueError: + logging.error("Invalid ISO 8601 date format for start or end time.") + return + + name = args.name + config_file_name = args.config + product = "QPE" + + db = get_db() + config = get_config(db, name) + n_rows = config["domain"]["n_rows"] + n_cols = config["domain"]["n_cols"] + n_levels = config["pysteps"]["n_cascade_levels"] + kmperpixel = 
config["pysteps"]["kmperpixel"] + + corl_df = get_auto_corls(db, product, start_time, end_time).dropna() + + if corl_df.empty: + logging.error( + "No valid correlation data found in the selected time range.") + return + + lag2_constants, lag2_exponents = [], [] + + for ilevel in range(n_levels): + a, b, rsq = fit_power_law(corl_df, ilevel) + if rsq < 0.5: + logging.info( + f"Warning: Rsq = {rsq:.2f}, using default power law for level {ilevel}") + a, b = 1.0, 2.4 + logging.info( + f"Level {ilevel}: lag2 = {a:.3f} * lag1^{b:.3f}, R² = {rsq:.2f}") + lag2_constants.append(a) + lag2_exponents.append(b) + + records = [] + for ilevel in range(n_levels): + lag1_col = f"lag1_{ilevel}" + lag2_col = f"lag2_{ilevel}" + + level_df = corl_df[["valid_time", lag1_col, lag2_col]].copy() + level_df["pl_lag2"] = lag2_constants[ilevel] * \ + np.power(level_df[lag1_col], lag2_exponents[ilevel]) + level_df[f"cl_{ilevel}"] = level_df.apply( + lambda row: correlation_length(row[lag1_col], row["pl_lag2"]), axis=1) + records.append( + level_df[["valid_time", f"cl_{ilevel}", lag1_col, lag2_col]]) + + cor_len_df = records[0] + for df in records[1:]: + cor_len_df = cor_len_df.merge(df, on="valid_time", how="outer") + + cor_len_df = cor_len_df.sort_values( + "valid_time").dropna().reset_index(drop=True) + + bp_filter = filter_gaussian((n_rows, n_cols), n_levels, kmperpixel) + scales = 1 / bp_filter["central_freqs"] + log_scales = np.log(scales) + + cl_columns = [f"cl_{i}" for i in range(n_levels)] + cl_data = cor_len_df[cl_columns].values + valid_mask = cl_data > 0 + log_cl_data = np.where(valid_mask, np.log(cl_data), np.nan) + + x_vals = np.tile(log_scales, (log_cl_data.shape[0], 1)).flatten() + y_vals = log_cl_data.flatten() + valid_idx = ~np.isnan(y_vals) + x_valid, y_valid = x_vals[valid_idx], y_vals[valid_idx] + + X = sm.add_constant(x_valid) + model = sm.OLS(y_valid, X).fit() + a, b = model.params + + print(model.summary()) + + # Median correlation length per scale (ignoring NaNs) + median_cl = np.nanmedian(log_cl_data, axis=0) + + # Scatter plot: log-scale vs median log(correlation length) + plt.figure(figsize=(8, 5)) + plt.scatter(log_scales, median_cl, label="Median log(correlation length)", color='blue') + + # Regression line + x_fit = np.linspace(min(log_scales), max(log_scales), 100) + y_fit = a + b * x_fit + plt.plot(x_fit, y_fit, color='red', label=f"OLS fit: y = {a:.2f} + {b:.2f}x") + + # Labels and formatting + plt.xlabel("log(Spatial scale [km])") + plt.ylabel("log(Correlation length [km])") + plt.title("Median correlation length vs scale (log-log)") + plt.grid(True) + plt.legend() + plt.tight_layout() + plt.show() + + percentiles = [95, 50, 5] + cl0_values = cor_len_df["cl_0"].values + pvals = np.percentile(cl0_values, percentiles) + + conf_dir = os.path.join("..", "run") + conf_path = os.path.join(conf_dir, config_file_name) + logging.info(f"Writing output dynamic scaling config to {conf_path} ") + with open(conf_path, "w") as f: + dynamic_scaling_config = {"dynamic_scaling": { + "central_wave_lengths": scales.tolist(), + "space_time_exponent": float(b), + "lag2_constants": lag2_constants, + "lag2_exponents": lag2_exponents, + "cor_len_percentiles": percentiles, + "cor_len_pvals": pvals.tolist() + }} + json.dump(dynamic_scaling_config, f, indent=2) + + generate_qaqc_plots(cor_len_df, b, scales, + lag2_constants, lag2_exponents, n_levels) + + +if __name__ == "__main__": + main() diff --git a/pysteps/param/cascade_utils.py b/pysteps/param/cascade_utils.py new file mode 100644 index 000000000..3ad4f3ad3 
--- /dev/null +++ b/pysteps/param/cascade_utils.py @@ -0,0 +1,39 @@ +import numpy as np + +def get_cascade_wavelengths(n_levels, domain_size_km, d=1.0, gauss_scale=0.5): + """ + Compute the central wavelengths (in km) for each cascade level. + + Parameters + ---------- + n_levels : int + Number of cascade levels. + domain_size_km : int or float + The larger of the two spatial dimensions (in km) of the domain. + d : float + Sample spacing (inverse of sampling rate). Default is 1. + gauss_scale : float + The Gaussian filter scaling parameter. + + Returns + ------- + wavelengths_km : np.ndarray + Central wavelengths in km for each cascade level (length = n_levels). + """ + # Compute q as in _gaussweights_1d + q = pow(0.5 * domain_size_km, 1.0 / n_levels) + + # Compute central wavenumbers (in grid units) + r = [(pow(q, k - 1), pow(q, k)) for k in range(1, n_levels + 1)] + central_wavenumbers = np.array([0.5 * (r0 + r1) for r0, r1 in r]) + + # Convert to frequency + central_freqs = central_wavenumbers / domain_size_km + central_freqs[0] = 1.0 / domain_size_km # enforce first freq > 0 + central_freqs[-1] = 0.5 # Nyquist limit + central_freqs *= d + + # Convert to wavelength (in km) + central_wavelengths_km = 1.0 / central_freqs + + return central_wavelengths_km diff --git a/pysteps/param/make_cascades.py b/pysteps/param/make_cascades.py new file mode 100644 index 000000000..d0b72ed2a --- /dev/null +++ b/pysteps/param/make_cascades.py @@ -0,0 +1,281 @@ +""" +Script to decompose and track the input rainfall fields and load them into the database + +""" + +from models import store_cascade_to_gridfs, replace_extension, read_nc +from models import get_db, get_config +from pymongo import MongoClient +import logging +import argparse +import gridfs +import pymongo +import numpy as np +import datetime +import os +import sys + +from pysteps import motion +from pysteps.utils import transformation +from pysteps.cascade.decomposition import decomposition_fft +from pysteps.cascade.bandpass_filters import filter_gaussian + +from urllib.parse import quote_plus + +WAR_THRESHOLD = 0.05 # Select only fields with rain for analysis + +def is_valid_iso8601(time_str: str) -> bool: + """Check if the given string is a valid ISO 8601 datetime.""" + try: + datetime.datetime.fromisoformat(time_str) + return True + except ValueError: + return False + + +def process_files(file_names: list[str], db: MongoClient, config: dict): + timestep = config["pysteps"]["timestep"] + db_zerovalue = config["pysteps"]["zerovalue"] + n_levels = config['pysteps']['n_cascade_levels'] + n_rows = config['domain']['n_rows'] + n_cols = config['domain']['n_cols'] + name = config['name'] + + oflow_method = motion.get_method("LK") # Lucas-Kanade method + bp_filter = filter_gaussian((n_rows, n_cols), n_levels) + + time_delta_tolerance = 120 + min_delta_time = datetime.timedelta( + seconds=timestep - time_delta_tolerance) + max_delta_time = datetime.timedelta( + seconds=timestep + time_delta_tolerance) + + rain_col_name = f"{name}.rain" + state_col_name = f"{name}.state" + rain_fs = gridfs.GridFS(db, collection=rain_col_name) + state_fs = gridfs.GridFS(db, collection=state_col_name) + + # Initialize buffers for batch processing + prev_time = None + cur_time = None + prev_field = None + cur_field = None + file_names.sort() + + for file_name in file_names: + grid_out = rain_fs.find_one({"filename": file_name}) + if grid_out is None: + logging.warning(f"File {file_name} not found in GridFS, skipping.") + continue + + # Extract metadata safely + 
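+        # The metadata copied below (product, domain, ensemble, base_time,
+        # valid_time and the field statistics) is attached to the cascade file
+        # so the state bucket can be queried with the same keys as the rain bucket.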
rain_fs_metadata = grid_out.metadata if hasattr(grid_out, "metadata") else {} + + if not rain_fs_metadata: + logging.warning(f"No metadata found for {file_name}, skipping.") + continue + + try: + + # Copy relevant metadata from rain_fs (MongoDB) to state_fs + field_metadata = { + "filename": replace_extension(grid_out.filename, ".npz"), + "product": rain_fs_metadata.get("product", "unknown"), + "domain": rain_fs_metadata.get("domain", "AKL"), + "ensemble": rain_fs_metadata.get("ensemble", None), + "base_time": rain_fs_metadata.get("base_time", None), + "valid_time": rain_fs_metadata.get("valid_time", None), + "mean": rain_fs_metadata.get("mean", 0), + "std_dev": rain_fs_metadata.get("std_dev", 0), + "wetted_area_ratio": rain_fs_metadata.get("wetted_area_ratio", 0) + } + + # Check if cascade already exists for this file + filename = replace_extension(grid_out.filename, ".npz") + existing_file = state_fs.find_one({"filename": filename}) + if existing_file: + state_fs.delete(existing_file._id) + + # Read the input NetCDF file + in_buffer = grid_out.read() + rain_geodata, valid_time, rain_data = read_nc(in_buffer) + + # Transform the field to dB if needed + if rain_geodata.get("transform") is None: + db_data, db_geodata = transformation.dB_transform( + rain_data, rain_geodata, threshold=0.1, zerovalue=db_zerovalue + ) + db_data[~np.isfinite(db_data)] = db_geodata["zerovalue"] + else: + db_data = rain_data.copy() + db_geodata = rain_geodata.copy() + + # Perform cascade decomposition + cascade_dict = decomposition_fft( + db_data, bp_filter, compute_stats=True, normalize=True + ) + + # Add the rain field transformation for the cascade + cascade_dict["transform"] = "dB" + cascade_dict["zerovalue"] = db_zerovalue + cascade_dict["threshold"] = -10 # Assumes db_transform threshold = 0.1 + + # Compute optical flow + if prev_time is None: + prev_time = valid_time + cur_time = valid_time + prev_field = db_data + cur_field = db_data + else: + prev_time = cur_time + prev_field = cur_field + cur_time = valid_time + cur_field = db_data + + # Compute motion field if the time difference is in the acceptable range + V1 = np.zeros((2, n_rows, n_cols)) + tdiff = cur_time - prev_time + if min_delta_time < tdiff < max_delta_time: + R = np.array([prev_field, cur_field]) + V1 = oflow_method(R) + + + # Store cascade and motion field in GridFS with metadata + store_cascade_to_gridfs( + db, name, cascade_dict, V1, field_metadata["filename"], field_metadata) + + except Exception as e: + logging.error(f"Error processing {grid_out.filename}: {e}") + + +def main(): + parser = argparse.ArgumentParser( + description="Decompose and track rainfall fields") + + + parser.add_argument('-s', '--start', type=str, required=True, + help='Start time yyyy-mm-ddTHH:MM:SS') + parser.add_argument('-e', '--end', type=str, required=True, + help='End time yyyy-mm-ddTHH:MM:SS') + parser.add_argument('-n', '--name', type=str, required=True, + help='Name of domain [AKL]') + parser.add_argument('-p', '--product', type=str, required=True, + help='Name of input product [QPE, auckprec]') + + args = parser.parse_args() + + # Include app name (module name) in log output + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(levelname)s] %(name)s: %(message)s', + stream=sys.stdout + ) + + logger = logging.getLogger(__name__) + logger.info("Starting cascade generation process") + + # Validate start and end time and read them in + if args.start and is_valid_iso8601(args.start): + start_time = 
datetime.datetime.fromisoformat(str(args.start)) + if start_time.tzinfo is None: + start_time = start_time.replace(tzinfo=datetime.timezone.utc) + else: + logging.error( + "Invalid start time format. Please provide a valid ISO 8601 time string.") + return + + if args.end and is_valid_iso8601(args.end): + end_time = datetime.datetime.fromisoformat(str(args.end)) + if end_time.tzinfo is None: + end_time = end_time.replace(tzinfo=datetime.timezone.utc) + else: + logging.error( + "Invalid start time format. Please provide a valid ISO 8601 time string.") + return + + name = str(args.name) + product = str(args.product) + if product not in ["QPE", "auckprec", "qpesim"]: + logging.error( + "Invalid product. Please provide either 'QPE' or 'auckprec'.") + return + + db = get_db() + config = get_config(db,name) + meta_coll = db[f"{name}.rain.files"] + + # Single pass through the data for qpe product + if product == "QPE": + f_filter = { + "metadata.product": product, + "metadata.valid_time": {"$gte": start_time, "$lte": end_time}, + "metadata.wetted_area_ratio": {"$gte": WAR_THRESHOLD} + } + + fields = {"_id": 0, "filename": 1, "metadata.wetted_area_ratio": 1} + results = meta_coll.find(filter=f_filter, projection=fields).sort( + "filename", pymongo.ASCENDING) + if results is None: + logging.error( + f"Failed to find {product}data for {start_time} - {end_time}") + return + + file_names = [doc["filename"] for doc in results] + logging.info( + f"Found {len(file_names)} {product} fields to process between {start_time} and {end_time}") + process_files(file_names, db, config) + else: + # Get the list of unique nwp run times in this period in ascending order + base_time_query = { + "metadata.product": product, + "metadata.base_time": {"$gte": start_time, "$lte": end_time} + } + base_times = meta_coll.distinct("metadata.base_time", base_time_query) + if base_times is None: + logging.error( + f"Failed to find {product} data for {start_time} - {end_time}") + return + + base_times.sort() + logging.info( + f"Found {len(base_times)} {product} NWP runs to process between {start_time} and {end_time}") + + for base_time in base_times: + logging.info(f"Processing NWP run {base_time}") + + # Get the list of unique ensembles found at base_time + ensembles = meta_coll.distinct( + "metadata.ensemble", {"metadata.product": product, "metadata.base_time": base_time}) + + if not ensembles: + logging.warning( + f"No ensembles found for base_time {base_time}") + continue # Skip this base_time if no ensembles exist + + logging.info( + f"Found {len(ensembles)} ensembles for base_time {base_time}") + ensembles.sort() + + for ensemble in ensembles: + # Get all the forecasts for this base_time and ensemble and process + + f_filter = { + "metadata.product": product, + "metadata.base_time": base_time, + "metadata.valid_time": {"$gte": start_time, "$lte": end_time}, + "metadata.ensemble": ensemble, + "metadata.wetted_area_ratio": {"$gte": WAR_THRESHOLD} + } + + fields = {"_id": 0, "filename": 1, + "metadata.wetted_area_ratio": 1} + results = meta_coll.find(filter=f_filter, projection=fields).sort( + "filename", pymongo.ASCENDING) + file_names = [doc["filename"] for doc in results] + + if len(file_names) > 0: + process_files(file_names, db, config) + + +if __name__ == "__main__": + main() diff --git a/pysteps/param/make_parameters.py b/pysteps/param/make_parameters.py new file mode 100644 index 000000000..34900c3b5 --- /dev/null +++ b/pysteps/param/make_parameters.py @@ -0,0 +1,413 @@ +""" +make_parameters.py +=================== + 
+Script to estimate the following STEPS parameters and place then in a MongoDB collection: +correl: lag 1 and 2 auto correlations for the cascade levels +b1, b2, l1: Slope of isotropic power spectrum above and below the scale l1 for rainfall field +mean, variance, wetted are ratio of the rainfall field +pdist: Sample cumulative probability distribution of rainfall field + +""" + +from models import read_nc, get_states, compute_field_parameters, get_db +from models import compute_field_stats, correlation_length + +import datetime +import numpy as np +import io +import pymongo +import gridfs +import argparse +import logging +from pymongo import MongoClient + +from pysteps.utils import transformation +from pysteps import extrapolation + +from urllib.parse import quote_plus +import os +import sys + +WAR_THRESHOLD = 0.05 # Select only fields with rain for analysis + + +def is_valid_iso8601(time_str: str) -> bool: + """Check if the given string is a valid ISO 8601 datetime.""" + try: + datetime.datetime.fromisoformat(time_str) + return True + except ValueError: + return False + + +def lagr_auto_cor(data: np.ndarray, oflow: np.ndarray, config: dict): + """ + Generate the Lagrangian auto correlations for STEPS cascades. + + Args: + data (np.ndarray): [T, L, M, N] where: + - T = ar_order + 1 (number of time steps) + - L = number of cascade levels + - M, N = spatial dimensions. + oflow (np.ndarray): [2, M, N] Optical flow vectors. + config (dict): Configuration dictionary containing: + - "n_cascade_levels": Number of cascade levels (L). + - "ar_order": Autoregressive order (1 or 2). + - "extrapolation_method": Method for extrapolating fields. + + Returns: + np.ndarray: Autocorrelation coefficients of shape (L, ar_order). + """ + + n_cascade_levels = config["pysteps"]["n_cascade_levels"] + ar_order = config["pysteps"]["ar_order"] + e_method = config["pysteps"]["extrapolation_method"] + + if data.shape[0] < (ar_order + 1): + raise ValueError( + f"Insufficient time steps. Expected at least {ar_order + 1}, got {data.shape[0]}.") + + extrapolation_method = extrapolation.get_method(e_method) + autocorrelation_coefficients = np.full( + (n_cascade_levels, ar_order), np.nan) + + for level in range(n_cascade_levels): + lag_1 = extrapolation_method(data[-2, level], oflow, 1)[0] + lag_1 = np.where(np.isfinite(lag_1), lag_1, 0) + + data_t = np.where(np.isfinite(data[-1, level]), data[-1, level], 0) + if np.std(lag_1) > 1e-1 and np.std(data_t) > 1e-1: + autocorrelation_coefficients[level, 0] = np.corrcoef( + lag_1.flatten(), data_t.flatten())[0, 1] + + if ar_order == 2: + lag_2 = extrapolation_method(data[-3, level], oflow, 1)[0] + lag_2 = np.where(np.isfinite(lag_2), lag_2, 0) + + lag_1 = extrapolation_method(lag_2, oflow, 1)[0] + lag_1 = np.where(np.isfinite(lag_1), lag_1, 0) + + if np.std(lag_1) > 1e-1 and np.std(data_t) > 1e-1: + autocorrelation_coefficients[level, 1] = np.corrcoef( + lag_1.flatten(), data_t.flatten())[0, 1] + + return autocorrelation_coefficients + + +def process_files(file_names, db, config: dict): + """ + Loop over a list of files and calculate the STEPS parameters. 
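+    For each field the rainfall and dBR statistics, the power spectrum and
+    probability distribution, and (when the neighbouring cascade states are
+    available) the Lagrangian lag-1/lag-2 auto-correlations and correlation
+    lengths of the cascade levels are computed.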
+ + Args: + file_names (list[str]): List of files to process + data_base (pymongo.MongoClient): MongoDB database + config (dict): Dictionary with pysteps configuration + + Returns: + list[dict]: List of steps parameter dictionaries + """ + ar_order = config["pysteps"]["ar_order"] + timestep = config["pysteps"]["timestep"] + time_step_mins = config["pysteps"]["timestep"] // 60 + db_zerovalue = config["pysteps"]["zerovalue"] + db_threshold = config["pysteps"]["threshold"] + scale_break = config['pysteps']["scale_break"] + kmperpixel = config['pysteps']["kmperpixel"] + name = config['name'] + + delta_time_step = datetime.timedelta(seconds=timestep) + + rain_col_name = f"{name}.rain" + rain_fs = gridfs.GridFS(db, collection=rain_col_name) + + params = [] + lag_2, lag_1, lag_0 = None, None, None + oflow_0 = None + + for file_name in file_names: + field = rain_fs.find_one({"filename": file_name}) + if field is None: + continue + + # Set up the field metadata + valid_time = field.metadata["valid_time"] + valid_time = valid_time.replace(tzinfo=datetime.timezone.utc) + base_time = field.metadata["base_time"] + if base_time is not None: + base_time = base_time.replace(tzinfo=datetime.timezone.utc) + + ensemble = field.metadata["ensemble"] + product = field.metadata["product"] + metadata = { + "field_id": field._id, + "product": product, + "base_time": base_time, + "ensemble": ensemble, + "valid_time": valid_time, + "kmperpixel": kmperpixel # Need this when generating the stochastic fields + } + + # Read in the rain field + in_buffer = field.read() + rain_geodata, _, rain_data = read_nc(in_buffer) + # Needs to be consistent with db_threshold = -10 + rain_geodata["threshold"] = 0.1 + rain_geodata["zerovalue"] = 0 + rain_stats = compute_field_stats(rain_data, rain_geodata) + + if rain_geodata["transform"] is None: + db_data, db_geodata = transformation.dB_transform( + rain_data, rain_geodata, threshold=0.1, zerovalue=db_zerovalue + ) + db_data[~np.isfinite(db_data)] = db_geodata["zerovalue"] + else: + db_data = rain_data.copy() + + db_geodata["threshold"] = db_threshold + db_geodata["zerovalue"] = db_zerovalue + db_stats = compute_field_stats(db_data, db_geodata) + + # Compute the power spectrum and prob dist + steps_params = compute_field_parameters( + db_data, db_geodata, scale_break) + + # Read in the cascades and calculate Lagr auto correlation + cascade = {} + + # Only calculate cascade parameters if enough rain + if rain_stats["nonzero_fraction"] < 0.05: + steps_params.update({ + "metadata": metadata, + "rain_stats": rain_stats, + "dbr_stats": db_stats, + "cascade": None + }) + params.append(steps_params) + continue + + # Fetch states for (t, t-1, t-2) + query = { + "metadata.product": product, + "metadata.valid_time": {"$in": [valid_time, valid_time - delta_time_step, valid_time - 2 * delta_time_step]}, + "metadata.base_time": base_time, + "metadata.ensemble": ensemble + } + + states = get_states(db, name, query) + + # Set up the keys for the states + bkey = "NA" if base_time is None else base_time + ekey = "NA" if ensemble is None else ensemble + lag0_inx = (valid_time, bkey, ekey) + lag1_inx = (valid_time - delta_time_step, bkey, ekey) + lag2_inx = (valid_time - 2*delta_time_step, bkey, ekey) + + state = states.get(lag0_inx) + lag_0 = state["cascade"] if state is not None else None + oflow_0 = state["optical_flow"] if state is not None else None + state = states.get(lag1_inx) + lag_1 = state["cascade"] if state is not None else None + state = states.get(lag2_inx) + lag_2 = state["cascade"] if 
state is not None else None + + # set up the cascade level means and stds for valid_time + stds = lag_0.get("stds") if lag_0 else None + cascade["stds"] = stds + means = lag_0.get("means") if lag_0 else None + cascade["means"] = means + + # Calculate the Lagr auto correl if enough data + num_valid = sum(x is not None for x in [lag_2, lag_1, lag_0]) + if num_valid == ar_order + 1 and oflow_0 is not None: + data = np.array([lag_1["cascade_levels"], lag_0["cascade_levels"]]) if ar_order == 1 else np.array( + [lag_2["cascade_levels"], lag_1["cascade_levels"], lag_0["cascade_levels"]]) + auto_cor = lagr_auto_cor(data, oflow_0, config) + + # calculate the correlation lengths (minutes) + lag1_list = auto_cor[:, 0].tolist() + lag2_list = auto_cor[:, 1].tolist() if ar_order == 2 else [ + None] * len(lag1_list) + corl_list = [ + correlation_length(l1, l2, time_step_mins) + for l1, l2 in zip(lag1_list, lag2_list) + ] + + cascade.update({ + "lag1": lag1_list, + "lag2": lag2_list, + "corl": corl_list, + "corl_zero":corl_list[0] + }) + else: + cascade.update({ + "lag1": None, + "lag2": None, + "corl": None, + "corl_zero":None + }) + + steps_params.update({ + "metadata": metadata, + "rain_stats": rain_stats, + "dbr_stats": db_stats, + "cascade": cascade + }) + params.append(steps_params) + + return params + + +def main(): + + parser = argparse.ArgumentParser( + description="Calculate STEPS parameters") + + parser.add_argument('-s', '--start', type=str, required=True, + help='Start time yyyy-mm-ddTHH:MM:SS') + parser.add_argument('-e', '--end', type=str, required=True, + help='End time yyyy-mm-ddTHH:MM:SS') + parser.add_argument('-n', '--name', type=str, required=True, + help='Name of domain [AKL]') + parser.add_argument('-p', '--product', type=str, required=True, + help='Name of input product [QPE, auckprec, qpesim]') + + args = parser.parse_args() + + # Include app name (module name) in log output + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(levelname)s] %(name)s: %(message)s', + stream=sys.stdout + ) + + logger = logging.getLogger(__name__) + logger.info("Calculating STEPS parameters") + + # Validate start and end time and read them in + if args.start and is_valid_iso8601(args.start): + start_time = datetime.datetime.fromisoformat(str(args.start)) + if start_time.tzinfo is None: + start_time = start_time.replace(tzinfo=datetime.timezone.utc) + else: + logging.error( + "Invalid start time format. Please provide a valid ISO 8601 time string.") + return + + if args.end and is_valid_iso8601(args.end): + end_time = datetime.datetime.fromisoformat(str(args.end)) + if end_time.tzinfo is None: + end_time = end_time.replace(tzinfo=datetime.timezone.utc) + else: + logging.error( + "Invalid start time format. Please provide a valid ISO 8601 time string.") + return + + name = str(args.name) + product = str(args.product) + if product not in ["QPE", "auckprec", "qpesim"]: + logging.error( + "Invalid product. 
Please provide either 'QPE', 'auckprec', or 'qpesim'.") + return + + db = get_db() + config_coll = db["config"] + record = config_coll.find_one({'config.name': name}, sort=[ + ('time', pymongo.DESCENDING)]) + if record is None: + logging.error(f"Could not find configuration for domain {name}") + return + + config = record['config'] + meta_coll = db[f"{name}.rain.files"] + params_coll = db[f"{name}.params"] + + # Single pass through the data for qpe product + if product == "QPE": + f_filter = { + "metadata.product": product, + "metadata.valid_time": {"$gte": start_time, "$lte": end_time}, + "metadata.wetted_area_ratio": {"$gte": WAR_THRESHOLD} + } + + fields = {"_id": 0, "filename": 1, "metadata.wetted_area_ratio": 1} + results = meta_coll.find(filter=f_filter, projection=fields).sort( + "filename", pymongo.ASCENDING) + if results is None: + logging.error( + f"Failed to find {product}data for {start_time} - {end_time}") + return + + file_names = [doc["filename"] for doc in results] + logging.info( + f"Found {len(file_names)} {product} fields to process between {start_time} and {end_time}") + steps_params = process_files(file_names, db, config) + + params_coll.delete_many({ + "metadata.product": product, + "metadata.valid_time": {"$gte": start_time, "$lte": end_time} + }) + if steps_params: + params_coll.insert_many(steps_params) + else: + # Get the list of unique nwp run times in this period in ascending order + start_base_time = start_time - datetime.timedelta(hours=12) + base_time_query = { + "metadata.product": product, + "metadata.base_time": {"$gte": start_base_time, "$lte": end_time} + } + base_times = meta_coll.distinct("metadata.base_time", base_time_query) + if base_times is None: + logging.error( + f"Failed to find {product} data for {start_time} - {end_time}") + return + + base_times.sort() + + for base_time in base_times: + + # Get the list of unique ensembles found at base_time + ensembles = meta_coll.distinct( + "metadata.ensemble", {"metadata.product": product, "metadata.base_time": base_time}) + + if not ensembles: + logging.warning( + f"No ensembles found for base_time {base_time}") + continue # Skip this base_time if no ensembles exist + + logging.info( + f"Found {len(ensembles)} ensembles for base_time {base_time}") + ensembles.sort() + + for ensemble in ensembles: + # Get all the forecasts for this base_time and ensemble and process + + f_filter = { + "metadata.product": product, + "metadata.valid_time": {"$gte": start_time, "$lte": end_time}, + "metadata.base_time": base_time, + "metadata.ensemble": ensemble, + "metadata.wetted_area_ratio": {"$gte": WAR_THRESHOLD} + } + + fields = {"_id": 0, "filename": 1, + "metadata.wetted_area_ratio": 1} + results = meta_coll.find(filter=f_filter, projection=fields).sort( + "filename", pymongo.ASCENDING) + file_names = [doc["filename"] for doc in results] + + if len(file_names) > 0: + steps_params = process_files( + file_names, db, config) + params_coll.delete_many({ + "metadata.product": product, + "metadata.valid_time": {"$gte": start_time, "$lte": end_time}, + "metadata.base_time": base_time, + "metadata.ensemble": ensemble, + }) + if steps_params: + params_coll.insert_many(steps_params) + + +if __name__ == "__main__": + main() diff --git a/pysteps/param/nwp_param_qc.py b/pysteps/param/nwp_param_qc.py new file mode 100644 index 000000000..6187d250c --- /dev/null +++ b/pysteps/param/nwp_param_qc.py @@ -0,0 +1,199 @@ +import argparse +import datetime +import logging +import numpy as np +import pandas as pd +from pymongo import 
UpdateOne +from models.mongo_access import get_db, get_config +from models.steps_params import power_law_acor, StochasticRainParameters + +from statsmodels.tsa.api import SimpleExpSmoothing +import pymongo.collection +from typing import Dict +logging.basicConfig(level=logging.INFO) + + +def get_parameters_df(query: Dict, param_coll: pymongo.collection.Collection) -> pd.DataFrame: + """ + Retrieve STEPS parameters from the database and return a DataFrame + indexed by (valid_time, base_time, ensemble), using 'NA' as sentinel for missing values. + + Args: + query (dict): MongoDB query dictionary. + param_coll (pymongo.collection.Collection): MongoDB collection. + + Returns: + pd.DataFrame: Indexed by (valid_time, base_time, ensemble), with a 'param' column. + """ + records = [] + + for doc in param_coll.find(query).sort("metadata.valid_time", pymongo.ASCENDING): + try: + metadata = doc.get("metadata", {}) + if metadata is None: + continue + + if doc["cascade"]["lag1"] is None or doc["cascade"]["lag2"] is None: + continue + + valid_time = metadata.get("valid_time") + valid_time = pd.to_datetime(valid_time,utc=True) + + base_time = metadata.get("base_time") + if base_time is None: + base_time = pd.NaT + else: + base_time = pd.to_datetime(base_time, utc=True) + + ensemble = metadata.get("ensemble") + + param = StochasticRainParameters.from_dict(doc) + + param.calc_corl() + records.append({ + "valid_time": valid_time, + "base_time": base_time, + "ensemble": ensemble, + "param": param + }) + except Exception as e: + print( + f"Warning: could not parse parameter for {metadata.get('valid_time')}: {e}") + + if not records: + return pd.DataFrame(columns=["valid_time", "base_time", "ensemble", "param"]) + + df = pd.DataFrame(records) + return df + +def parse_args(): + parser = argparse.ArgumentParser( + description="QC and update NWP lag autocorrelations") + parser.add_argument("-n", "--name", required=True, help="Domain name, e.g., AKL") + parser.add_argument("-p","--product", required=True, + help="Product name, e.g., auckprec") + parser.add_argument("-b","--base_time", required=True, + help="Base time, ISO format UTC (e.g., 2023-01-26T03:00:00)") + parser.add_argument("--dry_run", action="store_true", + help="Run without writing to database") + return parser.parse_args() + + +def qc_update_autocorrelations(dry_run: bool, name: str, product: str, base_time: datetime.datetime): + db = get_db() + config = get_config(db, name) + dt = datetime.timedelta(seconds=config["pysteps"]["timestep"]) + dt_seconds = dt.total_seconds() + + corl_pvals = config["dynamic_scaling"]["cor_len_pvals"] + corl_max = max(corl_pvals) + corl_min = min(corl_pvals) + + query = { + "metadata.product": product, + "metadata.base_time": base_time, + } + + param_coll = db[f"{name}.params"] + df = get_parameters_df(query, param_coll) + if df.empty: + logging.warning("No parameters found for the given base_time.") + return + + # Build corl_0 time series + records = [] + + ensembles = df["ensemble"].unique() + ensembles = np.sort(ensembles) + valid_times = df["valid_time"].unique() + t_min = min(valid_times) + t_max = max(valid_times) + all_times = pd.date_range(start=t_min, end=t_max, freq=dt, tz="UTC") + + # Convert the base_time to datetime64 for working with dataframe + vbase_time = pd.NaT + if base_time is not None: + vbase_time = pd.to_datetime(base_time,utc=True) + + for ens in ensembles: + ens_df = df.loc[ (df['base_time'] == vbase_time) & + (df['ensemble'] == ens),["valid_time", "param"] ].set_index("valid_time") + if 
ens_df.empty: + continue + + for vt in all_times: + try: + param = ens_df.loc[ vt,"param"] + corl_0 = param.corl_zero + + # Threshold at the 5 and 95 percentile values + corl_0 = corl_min if corl_0 < corl_min else corl_0 + corl_0 = corl_max if corl_0 > corl_max else corl_0 + except KeyError: + corl_0 = np.nan + + records.append({ + "valid_time": vt, + "ensemble": ens, + "corl_0": corl_0 + }) + + corl_df = pd.DataFrame.from_records(records) + corl_df = corl_df.sort_values(["ensemble", "valid_time"]) + updates = [] + + for ens in ensembles: + ens_df = corl_df[corl_df["ensemble"] == ens].set_index("valid_time") + if ens_df["corl_0"].isnull().all(): + logging.info(f"No valid corl_0 values for ensemble {ens}, skipping.") + continue + + mean_corl = ens_df["corl_0"].mean() + ens_df["corl_0"] = ens_df["corl_0"].fillna(mean_corl) + ens_df.index.freq = pd.Timedelta(seconds=dt_seconds) + + # Apply smoothing + model = SimpleExpSmoothing(ens_df["corl_0"], initialization_method="estimated").fit( + smoothing_level=0.2, optimized=False) + ens_df["corl_0_smoothed"] = model.fittedvalues + + for vt in ens_df.index: + T_ref = ens_df.loc[vt, "corl_0_smoothed"] + lags, corl = power_law_acor(config, T_ref) + valid_time = vt.to_pydatetime() + + updates.append(UpdateOne( + { + "metadata.product": product, + "metadata.valid_time": valid_time, + "metadata.base_time": base_time, + "metadata.ensemble": int(ens) + }, + { + "$set": { + "cascade.lag1": [float(x) for x in lags[:, 0]], + "cascade.lag2": [float(x) for x in lags[:, 1]], + "cascade.corl": [float(x) for x in corl], + "cascade.corl_zero":float(corl[0]) + } + }, + upsert=False + )) + if updates: + if dry_run: + logging.info( + f"{len(updates)} updates prepared (dry run, not written)") + else: + result = param_coll.bulk_write(updates) + logging.info(f"Updated {result.modified_count} documents.") + else: + logging.info("No documents to update.") + + +if __name__ == "__main__": + args = parse_args() + dry_run = args.dry_run + base_time = datetime.datetime.fromisoformat( + args.base_time).replace(tzinfo=datetime.timezone.utc) + qc_update_autocorrelations( + dry_run, args.name, args.product, base_time) diff --git a/pysteps/param/pysteps_param.py b/pysteps/param/pysteps_param.py new file mode 100644 index 000000000..8e09486ea --- /dev/null +++ b/pysteps/param/pysteps_param.py @@ -0,0 +1,369 @@ +from typing import List, Callable +import argparse +import logging +import datetime +import numpy as np +import copy +import os +import sys +import pandas as pd +from cascade.bandpass_filters import filter_gaussian +from utils import transformation +from cascade.decomposition import decomposition_fft +from mongo.nc_utils import generate_geo_data, make_nc_name_dt, write_netcdf +from mongo.gridfs_io import get_states, load_rain_field +from mongo.mongo_access import get_base_time, get_parameters_df +from steps_params import StochasticRainParameters, blend_parameters +from shared_utils import initialize_config +from shared_utils import zero_state, update_field + + +def get_weight(lag): + width = 3 * 3600 + weight = np.exp(-(lag/width)**2) + return weight + + +def main(): + + parser = argparse.ArgumentParser(description="Run nwpblend forecasts") + parser.add_argument('-b', '--base_time', required=True, + help='Base time in ISO 8601 format') + parser.add_argument('-n', '--name', required=True, + help='Domain name (e.g., AKL)') + args = parser.parse_args() + + + # Include app name (module name) in log output + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s 
[%(levelname)s] %(name)s: %(message)s', + stream=sys.stdout + ) + + logger = logging.getLogger(__name__) + logger.info("Gemerating nwpblend ensembles") + + name = args.name + db, config, out_base_time = initialize_config(args.base_time, name) + + param_coll = db[f"{name}.params"] + meta_coll = db[f"{name}.rain.files"] + + time_step_seconds = config['pysteps']['timestep'] + time_step = datetime.timedelta(seconds=time_step_seconds) + ar_order = config['pysteps']['ar_order'] + n_levels = config['pysteps']['n_cascade_levels'] + db_threshold = config['pysteps']['threshold'] + scale_break = config['pysteps']['scale_break'] + + # Set up the georeferencing data for the output forecasts + domain = config['domain'] + start_x = domain['start_x'] + start_y = domain['start_y'] + p_size = domain['p_size'] + n_rows = domain['n_rows'] + n_cols = domain['n_cols'] + x = [start_x + i * p_size for i in range(n_cols)] + y = [start_y + i * p_size for i in range(n_rows)] + geo_data = generate_geo_data(x, y) + geo_data["projection"] = config['projection']["epsg"] + + # Set up the bandpass filter + p_size_km = p_size / 1000.0 + bp_filter = filter_gaussian((n_rows, n_cols), n_levels, d=p_size_km) + + # Configure the output product + out_product = "nwpblend" + out_config = config['output'][out_product] + n_ens = out_config.get('n_ens_members', 10) + n_forecasts = out_config.get('n_forecasts', 12) + rad_product = out_config.get('rad_product', None) + nwp_product = out_config.get('nwp_product', None) + gridfs_out = out_config.get('gridfs_out', False) + nc_out = out_config.get('nc_out', False) + out_dir_name = out_config.get('out_dir_name', None) + out_file_name = out_config.get( + 'out_file_name', "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc") + + # Validate the output configuration details + if rad_product is None: + logging.error(f"Radar product not specified") + return + if n_ens < 1: + logging.error(f"Invalid number of ensemble members: {n_ens}") + return + if n_forecasts < 1: + logging.error(f"Invalid number of lead times: {n_forecasts}") + return + if not gridfs_out and not nc_out: + logging.error( + "No output format specified. 
Please set either gridfs_out or nc_out to True.") + return + + if nc_out: + if out_dir_name is None: + logging.error(f"No output directory name found") + return + + logging.info(f"Generating nwpblend for {out_base_time}") + + # Make the list of output forecast times + forecast_times = [out_base_time + ia * + time_step for ia in range(0, n_forecasts+1)] + + + # Get the initial state(s) for the input radar field at this base time + # Set any missing states to None + base_time_key = "NA" + ensemble_key = "NA" + init_times = [out_base_time] + if ar_order == 2: + init_times = [out_base_time - time_step, out_base_time] + + query = { + "metadata.product": rad_product, + "metadata.valid_time": {"$in": init_times} + } + init_state = get_states(db, name, query, + get_cascade=True, get_optical_flow=True) + rad_params_df = get_parameters_df(query, param_coll) + + for vtime in init_times: + key = (vtime, base_time_key, ensemble_key) + if key not in init_state: + init_state[key] = zero_state(config) + logging.debug(f"Found missing QPE oflow for {vtime}") + + # Check if row exists for this combination + mask = ( + (rad_params_df["valid_time"] == vtime) & + (rad_params_df["base_time"] == "NA") & + (rad_params_df["ensemble"] == "NA") + ) + + if rad_params_df[mask].empty: + logging.debug(f"Found missing QPE parametersfor {vtime}") + + def_param = StochasticRainParameters() + def_param.calc_acor(config) + def_param.kmperpixel = p_size_km + def_param.scale_break = scale_break + def_param.threshold = db_threshold + + new_row = { + "valid_time": vtime, + "base_time": "NA", + "ensemble": "NA", + "param": def_param + } + rad_params_df = pd.concat([rad_params_df, pd.DataFrame([new_row])], ignore_index=True) + + # Get the base_time for the nwp run nearest to the output base_time + nwp_base_time = get_base_time(out_base_time, nwp_product, name, db) + + # Get the list of ensemble members for this nwp_base_time + query = { + "metadata.product": nwp_product, + "metadata.base_time": nwp_base_time} + nwp_ensembles = meta_coll.distinct("metadata.ensemble", query) + if nwp_ensembles is None: + logging.warning( + f"Failed to find ensembles for {nwp_product} data for {out_base_time}") + nwp_ensembles.sort() + n_nwp_ens = len(nwp_ensembles) + + # Get the NWP parameters and optical flows for the NWP ensemble + query = { + "metadata.product": nwp_product, + "metadata.valid_time": {"$in": forecast_times}, + 'metadata.base_time': nwp_base_time + } + nwp_params_df = get_parameters_df(query, param_coll) + nwp_oflows = get_states( + db, name, query, get_cascade=False, get_optical_flow=True) + + # Start the loop over the ensemble members + for iens in range(n_ens): + + # Calculate the set of blended parameters for this output ensemble + # Get the radar parameter + qpe_rows = rad_params_df[ + (rad_params_df["valid_time"] == out_base_time) & + (rad_params_df["base_time"] == "NA") & + (rad_params_df["ensemble"] == "NA") + ] + rad_param = qpe_rows.iloc[0]["param"] + + # Randomly select an ensemble member from the NWP + nwp_ens = np.random.randint(low=0, high=n_nwp_ens) + nwp_ensemble_df = nwp_params_df[ + (nwp_params_df["base_time"] == nwp_base_time) & + (nwp_params_df["ensemble"] == nwp_ens) + ][["valid_time", "param"]].copy() + nwp_ensemble_df["valid_time"] = pd.to_datetime(nwp_ensemble_df["valid_time"]) + nwp_ensemble_df.set_index("valid_time", inplace=True) + nwp_ensemble_df = nwp_ensemble_df.sort_index() + + # Fill in any missing forecast times with default parameters + for vtime in forecast_times: + if vtime not in 
nwp_ensemble_df.index: + def_param = StochasticRainParameters() + def_param.calc_acor(config) + def_param.kmperpixel = p_size_km + def_param.scale_break = scale_break + def_param.threshold = db_threshold + nwp_ensemble_df.loc[vtime,"param"] = def_param + + # Blend the parameters + blend_params_df = blend_parameters(config, out_base_time, nwp_ensemble_df, rad_param) + + # Set up the initial conditions for the forecast loop + # The order is [t-1, t0] in init_times for AR(2) + if ar_order == 1: + key = (init_times[0], "NA", "NA") + state = init_state.get(key) + + if state is not None: + cascade = state.get("cascade") + optical = state.get("optical_flow") + fx_cascades = [copy.deepcopy(cascade)] if cascade is not None else [None] + fx_oflow = copy.deepcopy(optical) if optical is not None else None + else: + fx_cascades = [None] + fx_oflow = None + + else: # AR(2) + key_0 = (init_times[0], "NA", "NA") + key_1 = (init_times[1], "NA", "NA") + + state_0 = init_state.get(key_0) + state_1 = init_state.get(key_1) + + if state_0 is not None and state_1 is not None: + casc_0 = state_0.get("cascade") + casc_1 = state_1.get("cascade") + optical = state_1.get("optical_flow") + + fx_cascades = [ + copy.deepcopy(casc_0) if casc_0 is not None else None, + copy.deepcopy(casc_1) if casc_1 is not None else None + ] + fx_oflow = copy.deepcopy(optical) if optical is not None else None + else: + fx_cascades = [None, None] + fx_oflow = None + + # Start the forecast loop + for ifx in range(1, n_forecasts+1): + valid_time = forecast_times[ifx] + fx_param = blend_params_df.loc[valid_time, "param"] + + fx_dbrain = update_field( + fx_cascades, fx_oflow, fx_param, bp_filter, config) + has_nan = np.isnan(fx_dbrain).any() if fx_dbrain is not None else True + + if has_nan : + fx_rain = np.zeros((n_rows, n_cols)) + else: + fx_rain, _ = transformation.dB_transform( + fx_dbrain, inverse=True, threshold=db_threshold, zerovalue=0) + + # Make the output file name + fx_file_name = make_nc_name_dt( + out_file_name, name, out_product, valid_time, out_base_time, iens) + + # Write the NetCDF data to a memoryview buffer + # This is an ugly hack on time zones + vtime = valid_time + if vtime.tzinfo is None: + vtime = vtime.replace(tzinfo=datetime.timezone.utc) + btime = out_base_time + if btime.tzinfo is None: + btime = btime.replace(tzinfo=datetime.timezone.utc) + vtime_stamp = vtime.timestamp() + + nc_buf = write_netcdf(fx_rain, geo_data, vtime_stamp) + + if gridfs_out: + # Create metadata + rain_mask = fx_rain.copy() + rain_mask[rain_mask < 1] = 0 + rain_mask[rain_mask > 0] = 1 + war = rain_mask.sum() / (n_cols * n_rows) + mean = np.nanmean(fx_rain) + std_dev = np.nanstd(fx_rain) + max = np.nanmax(fx_rain) + metadata = { + "product": out_product, + "domain": name, + "ensemble": int(iens), + "base_time": btime, + "valid_time": vtime, + "mean": float(mean), + "wetted_area_ratio": float(war), + "std_dev": float(std_dev), + "max": float(max), + "forecast_lead_time": int(ifx*time_step_seconds) + } + load_rain_field(db, name, fx_file_name, nc_buf, metadata) + + if nc_out: + fx_dir_name = make_nc_name_dt( + out_dir_name, name, out_product, valid_time, out_base_time, iens) + if not os.path.exists(fx_dir_name): + os.makedirs(fx_dir_name) + fx_file_path = os.path.join(fx_dir_name, fx_file_name) + with open(fx_file_path, 'wb') as f: + f.write(nc_buf.tobytes()) + + # Update the cascade state list for the next forecast step + if ar_order == 2: + # Push the cascade history down (t0 → t-1) + fx_cascades[0] = copy.deepcopy(fx_cascades[1]) + + # 
Update the latest cascade (t0) from current forecast brain + if fx_dbrain is not None: + if has_nan: + fx_cascades[1] = zero_state(config)["cascade"] + logging.warning(f"NaNs found for {valid_time}, {iens} ") + else: + fx_cascades[1] = decomposition_fft( + fx_dbrain, bp_filter, compute_stats=True, normalize=True + ) + else: + fx_cascades[1] = zero_state(config)["cascade"] + + elif ar_order == 1: + # Only update the current cascade + if fx_dbrain is not None: + if has_nan: + fx_cascades[0] = zero_state(config)["cascade"] + logging.warning(f"NaNs found for {valid_time}, {iens} ") + else: + fx_cascades[0] = decomposition_fft( + fx_dbrain, bp_filter, compute_stats=True, normalize=True + ) + else: + fx_cascades[0] = zero_state(config)["cascade"] + + # Update the optical flow field using radar–NWP blending + if ifx < n_forecasts: + rad_key = (out_base_time, "NA", "NA") + nwp_key = (out_base_time, nwp_base_time, nwp_ens) + + lag = (valid_time - out_base_time).total_seconds() + weight = get_weight(lag) + + # Check availability of both radar and NWP optical flows + rad_oflow = init_state.get(rad_key, {}).get("optical_flow") + nwp_oflow_entry = nwp_oflows.get(nwp_key) + nwp_oflow = nwp_oflow_entry.get("optical_flow") if nwp_oflow_entry else None + + if rad_oflow is not None and nwp_oflow is not None: + fx_oflow = weight * rad_oflow + (1 - weight) * nwp_oflow + else: + fx_oflow = None + +if __name__ == "__main__": + main() diff --git a/pysteps/param/shared_utils.py b/pysteps/param/shared_utils.py new file mode 100644 index 000000000..2d37199fa --- /dev/null +++ b/pysteps/param/shared_utils.py @@ -0,0 +1,181 @@ +import datetime +import logging +import numpy as np +from pysteps.cascade.decomposition import decomposition_fft, recompose_fft +from pysteps.timeseries import autoregression +from pysteps import extrapolation +from models.mongo_access import get_config, get_db +from models.steps_params import StochasticRainParameters +from models.stochastic_generator import gen_stoch_field, normalize_db_field + +def initialize_config(base_time_str, name): + try: + base_time = datetime.datetime.fromisoformat(base_time_str).replace(tzinfo=datetime.timezone.utc) + except ValueError: + raise ValueError(f"Invalid base time format: {base_time_str}") + + db = get_db() + config = get_config(db, name) + if config is None: + raise RuntimeError(f"Configuration not found for domain {name}") + + return db, config, base_time + + +def prepare_forecast_loop(db, config, base_time, name, product): + print(f"Running {product} for domain {name} at {base_time}") + # Placeholder for forecast generation logic + # This is where you'd insert the time loop, forecast logic, and output handling + pass + + +def update_field(cascades: list, optical_flow: np.ndarray, params: StochasticRainParameters, bp_filter: dict, config: dict) -> np.ndarray: + """ + Update a rainfall field using the parametric STEPS algorithm. + + Args: + cascades (list): List of rainfall cascades for previous ar_order time steps. + optical_flow (np.ndarray): Optical flow field for Lagrangian updates. + params (StochasticRainParameters): Parameters for the update. 
+ bp_filter: Bandpass filter dictionary returned by pysteps.cascade.bandpass_filters.filter_gaussian + config: The configuration dictionary + + Returns: + np.ndarray: Updated rainfall field in decibels (dB) of rain intensity + """ + ar_order = config['pysteps']['ar_order'] + n_levels = config['pysteps']['n_cascade_levels'] + n_rows = config['domain']['n_rows'] + n_cols = config['domain']['n_cols'] + + # Ensure that we have valid input parameters + number_none_states = sum(1 for v in cascades if v is None) + if (number_none_states != 0) or (optical_flow is None) or (params is None): + logging.debug( + "Missing cascade values, skipping forecast.") + return None + + # Calculate the AR phi parameters, check if there any cascade parameters + if params.cascade_lag1 is None: + logging.debug( + "No valid cascade lag1 values found in the parameters. Skipping forecast.") + return None + if ar_order == 2 and params.cascade_lag2 is None: + logging.debug( + "No valid cascade lag2 values found in the parameters. Skipping forecast.") + return None + + # Check if the lag 1 and lag 2 are all valid + number_none_lag1 = sum(1 for v in params.cascade_lag1 if np.isnan(v)) + number_none_lag2 = 0 + if ar_order == 2: + number_none_lag2 = sum(1 for v in params.cascade_lag2 if np.isnan(v)) + + # Fill the lag1 and lag2 with the default parameters + if number_none_lag1 != 0 or number_none_lag2 != 0: + params.corl_zero = config["dynamic_scaling"]["cor_len_pvals"][1] + params.calc_acor(config) + + phi = np.zeros((n_levels, ar_order + 1)) + for ilev in range(n_levels): + gamma_1 = params.cascade_lag1[ilev] + if ar_order == 2: + gamma_2 = autoregression.adjust_lag2_corrcoef2( + gamma_1, params.cascade_lag2[ilev]) + phi[ilev] = autoregression.estimate_ar_params_yw( + [gamma_1, gamma_2]) + else: + phi[ilev] = autoregression.estimate_ar_params_yw( + [gamma_1]) + + # Generate the noise field and cascade + noise_field = gen_stoch_field(params, n_cols, n_rows) + max_dbr = 10*np.log10(150) + min_dbr = 10*np.log10(0.05) + noise_field = np.clip(noise_field, min_dbr, max_dbr) + noise_cascade = decomposition_fft( + noise_field, bp_filter, compute_stats=True, normalize=True) + + # Update the cascade + extrapolation_method = extrapolation.get_method("semilagrangian") + lag_0 = np.zeros((n_levels, n_rows, n_cols)) + if ar_order == 1: + lag_1 = cascades[0]["cascade_levels"] + else: + lag_2 = cascades[0]["cascade_levels"] + lag_1 = cascades[1]["cascade_levels"] + + # Loop over cascade levels + for ilev in range(n_levels): + # Set the outside pixels to zero + adv_lag1 = extrapolation_method( + lag_1[ilev], optical_flow, 1, outval=0)[0] + if ar_order == 1: + lag_0[ilev] = phi[ilev, 0] * adv_lag1 + \ + phi[ilev, 1] * noise_cascade["cascade_levels"][ilev] + + else: + # Set the outside pixels to zero + adv_lag2 = extrapolation_method( + lag_2[ilev], optical_flow, 2, outval=0)[1] + lag_0[ilev] = phi[ilev, 0] * adv_lag1 + phi[ilev, 1] * \ + adv_lag2 + phi[ilev, 2] * noise_cascade["cascade_levels"][ilev] + + # Make sure we have mean = 0, stdev = 1 + lev_mean = np.mean(lag_0) + lev_stdev = np.std(lag_0) + if lev_stdev > 1e-1: + lag_0 = (lag_0 - lev_mean)/lev_stdev + + # Recompose the cascade into a single field + updated_cascade = {} + updated_cascade["domain"] = "spatial" + updated_cascade["normalized"] = True + updated_cascade["compact_output"] = False + updated_cascade["cascade_levels"] = lag_0.copy() + + # Use the noise cascade level stds + updated_cascade["means"] = noise_cascade["means"].copy() + updated_cascade["stds"] = 
noise_cascade["stds"].copy() + + # Reduce the bias in the last cascade level due to the gradient in rain / no rain + high_freq_bias = 0.80 + updated_cascade["stds"][-1] *= high_freq_bias + gen_field = recompose_fft(updated_cascade) + + # Normalise the field to have the expected conditional mean and variance + norm_field = normalize_db_field(gen_field, params) + + return norm_field + +def zero_state(config): + n_cascade_levels = config['pysteps']['n_cascade_levels'] + n_rows = config['domain']['n_rows'] + n_cols = config['domain']['n_cols'] + metadata_dict = { + "transform": config['pysteps']['transform'], + "threshold": config['pysteps']['threshold'], + "zerovalue": config['pysteps']['zerovalue'], + "mean": float(0), + "std_dev": float(0), + "wetted_area_ratio": float(0) + } + cascade_dict = { + "cascade_levels": np.zeros((n_cascade_levels, n_rows, n_cols)), + "means": np.zeros(n_cascade_levels), + "stds": np.zeros(n_cascade_levels), + "domain": 'spatial', + "normalized": True, + } + oflow = np.zeros((2, n_rows, n_cols)) + state = { + "cascade": cascade_dict, + "optical_flow": oflow, + "metadata": metadata_dict + } + return state + + +def is_zero_state(state, tol=1e-6): + return abs(state["metadata"]["mean"]) < tol + diff --git a/pysteps/param/steps_params.py b/pysteps/param/steps_params.py new file mode 100644 index 000000000..16763008d --- /dev/null +++ b/pysteps/param/steps_params.py @@ -0,0 +1,555 @@ +# Contains: StochasticRainParameters (dataclass, from_dict, to_dict), compute_field_parameters, compute_field_stats +""" + Functions to implement the parametric version of STEPS +""" +from typing import Optional, Tuple, Dict, Union, List, Callable +import datetime +import copy +import logging +from typing import Optional, List, Dict, Any +from dataclasses import dataclass +import xarray as xr +from scipy.optimize import curve_fit +import numpy as np +import pandas as pd + +MAX_RAIN_RATE = 250 +N_BINS = 200 + +@dataclass +class StochasticRainParameters: + transform: Optional[str] = None + zerovalue: Optional[float] = None + threshold: Optional[float] = None + kmperpixel: Optional[float] = None + + mean_db: Optional[float] = None + stdev_db: Optional[float] = None + nonzero_mean_rain: Optional[float] = None + nonzero_stdev_rain: Optional[float] = None + mean_rain: Optional[float] = None + stdev_rain: Optional[float] = None + + psd: Optional[List[float]] = None + psd_bins: Optional[List[float]] = None + c1: Optional[float] = None + c2: Optional[float] = None + scale_break: Optional[float] = None + cdf: Optional[List[float]] = None + cdf_bins: Optional[List[float]] = None + cascade_stds: Optional[List[float]] = None + cascade_means: Optional[List[float]] = None + cascade_lag1: Optional[List[float]] = None + cascade_lag2: Optional[List[float]] = None + cascade_corl: Optional[List[float]] = None + + product: Optional[str] = None + valid_time: Optional[datetime.datetime] = None + base_time: Optional[datetime.datetime] = None + ensemble: Optional[int] = None + field_id: Optional[str] = None + + # Defaulted parameters + nonzero_mean_db: float = 2.3 + nonzero_stdev_db: float = 5.6 + rain_fraction: float = 0 + beta_1: float = -2.06 + beta_2: float = 3.2 + corl_zero: float = 260 + + def get(self, key: str, default: Any = None) -> Any: + """Mimic dict.get() for selected attributes.""" + return getattr(self, key, default) + + def calc_corl(self): + """Populate the correlation lengths using lag1 and lag2 values.""" + if self.cascade_lag1 is None or self.cascade_lag2 is None: + return + + n_levels = 
len(self.cascade_lag1) + if len(self.cascade_corl) != n_levels: + self.cascade_corl = [np.nan] * n_levels + + for ilev in range(n_levels): + lag1 = self.cascade_lag1[ilev] + lag2 = self.cascade_lag2[ilev] + self.cascade_corl[ilev] = correlation_length(lag1, lag2) + + # Convenience for blending with radar + self.corl_zero = self.cascade_corl[0] + + def calc_acor(self, config) -> None: + T_ref = self.corl_zero + if T_ref is None or np.isnan(T_ref): + T_ref = config["dynamic_scaling"]["cor_len_pvals"][1] + + acor, corl = power_law_acor(config, T_ref) + self.cascade_corl = [float(x) for x in corl] + self.cascade_lag1 = [float(x) for x in acor[:, 0]] + self.cascade_lag2 = [float(x) for x in acor[:, 1]] + + @classmethod + + def from_dict(cls, data: Dict[str, Any]) -> "StochasticRainParameters": + dbr = data.get("dbr_stats", {}) + rain = data.get("rain_stats", {}) + pspec = data.get("power_spectrum", {}) + model = pspec.get("model", {}) if pspec else {} + cdf_data = data.get("cdf", {}) + cascade = data.get("cascade", {}) + meta = data.get("metadata", {}) + + return cls( + product=meta.get("product"), + valid_time=meta.get("valid_time"), + base_time=meta.get("base_time"), + ensemble=meta.get("ensemble"), + field_id=meta.get("field_id"), + transform=dbr.get("transform"), + zerovalue=dbr.get("zerovalue"), + threshold=dbr.get("threshold"), + kmperpixel=meta.get("kmperpixel"), + + nonzero_mean_db=dbr.get("nonzero_mean"), + nonzero_stdev_db=dbr.get("nonzero_stdev"), + rain_fraction=dbr.get("nonzero_fraction"), + mean_db=dbr.get("mean"), + stdev_db=dbr.get("stdev"), + nonzero_mean_rain=rain.get("nonzero_mean"), + nonzero_stdev_rain=rain.get("nonzero_stdev"), + mean_rain=rain.get("mean"), + stdev_rain=rain.get("stdev"), + + psd=pspec.get("psd", []), + psd_bins=pspec.get("psd_bins", []), + + beta_1=model.get("beta_1"), + beta_2=model.get("beta_2"), + c1=model.get("c1"), + c2=model.get("c2"), + scale_break=model.get("scale_break"), + + cdf=cdf_data.get("cdf", []), + cdf_bins=cdf_data.get("cdf_bins", []), + + corl_zero=cascade.get("corl_zero"), + cascade_stds=cascade.get("stds"), + cascade_means=cascade.get("means"), + cascade_lag1=cascade.get("lag1"), + cascade_lag2=cascade.get("lag2"), + cascade_corl=[np.nan] * len(cascade.get("lag1", [])) + ) + + def to_dict(self) -> Dict[str, Any]: + return { + "dbr_stats": { + "transform": self.transform, + "zerovalue": self.zerovalue, + "threshold": self.threshold, + "nonzero_mean": self.nonzero_mean_db, + "nonzero_stdev": self.nonzero_stdev_db, + "nonzero_fraction": self.rain_fraction, + "mean": self.mean_db, + "stdev": self.stdev_db, + }, + "rain_stats": { + "nonzero_mean": self.nonzero_mean_rain, + "nonzero_stdev": self.nonzero_stdev_rain, + "nonzero_fraction": self.rain_fraction, # assume same as dbr_stats + "mean": self.mean_rain, + "stdev": self.stdev_rain, + "transform": None, + "zerovalue": 0, + "threshold": 0.1, + }, + "power_spectrum": { + "psd": self.psd, + "psd_bins": self.psd_bins, + "model": { + "beta_1": self.beta_1, + "beta_2": self.beta_2, + "c1": self.c1, + "c2": self.c2, + "scale_break": self.scale_break, + } if any(x is not None for x in [self.beta_1, self.beta_2, self.c1, self.c2, self.scale_break]) else None + }, + "cdf": { + "cdf": self.cdf, + "cdf_bins": self.cdf_bins, + }, + "cascade": { + "corl_zero":self.corl_zero, + "stds": self.cascade_stds, + "means": self.cascade_means, + "lag1": self.cascade_lag1, + "lag2": self.cascade_lag2, + "corl": self.cascade_corl, + } if self.cascade_stds is not None else None, + "metadata": { + "kmperpixel": 
self.kmperpixel, + "product": self.product, + "valid_time": self.valid_time, + "base_time": self.base_time, + "ensemble": self.ensemble, + "field_id": self.field_id, + } + } + + +def compute_field_parameters(db_data: np.ndarray, db_metadata: dict, scale_break_km: Optional[float] = None): + """ + Compute STEPS parameters for the dB transformed rainfall field + + Args: + db_data (np.ndarray): 2D field of dB-transformed rain. + db_metadata (dict): pysteps metadata dictionary. + + Returns: + dict: Dictionary containing STEPS parameters. + """ + + # Compute power spectrum model + if scale_break_km is not None: + scalebreak = scale_break_km * 1000.0 / db_metadata["xpixelsize"] + else: + scalebreak = None + ps_dataset, ps_model = power_spectrum_1D(db_data, scalebreak) + power_spectrum = { + "psd": ps_dataset.psd.values.tolist(), + "psd_bins": ps_dataset.psd_bins.values.tolist(), + "model": ps_model + } + + # Compute cumulative probability distribution + cdf_dataset = prob_dist(db_data, db_metadata) + cdf = { + "cdf": cdf_dataset.cdf.values.tolist(), + "cdf_bins": cdf_dataset.cdf_bins.values.tolist(), + } + + # Store parameters in a dictionary + steps_params = { + "timestamp": datetime.datetime.now(datetime.timezone.utc), + "power_spectrum": power_spectrum, + "cdf": cdf + } + return steps_params + + +def power_spectrum_1D(field: np.ndarray, scale_break: Optional[float] = None + ) -> Tuple[Optional[xr.Dataset], Optional[Dict[str, float]]]: + """ + Calculate the 1D isotropic power spectrum and fit a power law model. + + Args: + field (np.ndarray): 2D input field in [rows, columns] order. + scale_break (float, optional): Scale break in pixel units. If None, fit single line. + + Returns: + ps_dataset (xarray.Dataset): 1D isotropic power spectrum in dB. + model_params (dict): Dictionary with model parameters: beta_1, beta_2, c1, c2, scale_break + """ + min_stdev = 0.1 + mean = np.nanmean(field) + stdev = np.nanstd(field) + if stdev < min_stdev: + return None, None + + norm_field = (field - mean) / stdev + np.nan_to_num(norm_field, copy=False) + + field_fft = np.fft.rfft2(norm_field) + power_spectrum = np.abs(field_fft) ** 2 + + freq_x = np.fft.fftfreq(field.shape[1]) + freq_y = np.fft.fftfreq(field.shape[0]) + freq_r = np.sqrt(freq_x[:, None]**2 + freq_y[None, :]**2) + freq_r = freq_r[: field.shape[0] // 2, : field.shape[1] // 2] + power_spectrum = power_spectrum[: field.shape[0] // + 2, : field.shape[1] // 2] + + n_bins = power_spectrum.shape[0] + bins = np.logspace(np.log10(freq_r.min() + 1 / n_bins), + np.log10(freq_r.max()), num=n_bins) + bin_centers = (bins[:-1] + bins[1:]) / 2 + power_1d = np.zeros(len(bin_centers)) + + for i in range(len(bins) - 1): + mask = (freq_r >= bins[i]) & (freq_r < bins[i + 1]) + power_1d[i] = np.nanmean( + power_spectrum[mask]) if np.any(mask) else np.nan + + valid = (bin_centers > 0) & (~np.isnan(power_1d)) + bin_centers = bin_centers[valid] + power_1d = power_1d[valid] + + if len(bin_centers) == 0: + return None, None + + log_x = 10*np.log10(bin_centers) + log_y = 10*np.log10(power_1d) + + start_idx = 2 + end_idx = np.searchsorted(log_x, -4.0) + + model_params = {} + + if scale_break is None: + def str_line(X, m, c): return m * X + c + popt, _ = curve_fit( + str_line, log_x[start_idx:end_idx], log_y[start_idx:end_idx]) + beta_1, c1 = popt + beta_2 = None + c2 = None + sb_log = None + else: + sb_freq = 1.0 / scale_break + sb_log = 10*np.log10(sb_freq) + + def piecewise_linear(x, m1, m2, c1): + c2 = (m1 - m2) * sb_log + c1 + return np.where(x <= sb_log, m1 * x + c1, m2 
* x + c2) + + popt, _ = curve_fit( + piecewise_linear, log_x[start_idx:end_idx], log_y[start_idx:end_idx]) + beta_1, beta_2, c1 = popt + c2 = (beta_1 - beta_2) * sb_log + c1 + + ps_dataset = xr.Dataset( + {"psd": (["bin"], log_y)}, + coords={"psd_bins": (["bin"], log_x)}, + attrs={"description": "1-D Isotropic power spectrum", "units": "dB"} + ) + + model_params = { + "beta_1": float(beta_1), + "beta_2": float(beta_2), + "c1": float(c1), + "c2": float(c2), + "scale_break": float(scale_break) + } + + return ps_dataset, model_params + + +def prob_dist(data: np.ndarray, metadata: dict): + """ + Calculate the cumulative probability distribution for rain > threshold for dB field + + Args: + data (np.ndarray): 2D field of dB-transformed rain. + metadata (dict): pysteps metadata dictionary. + + Returns: + tuple: + - xarray Dataset containing the cumulative probability distribution and bin edges + - fraction of field with rain > threshold (float) + """ + + rain_mask = data > metadata["zerovalue"] + + # Compute cumulative probability distribution + min_db = metadata["zerovalue"] + 0.1 + max_db = 10 * np.log10(MAX_RAIN_RATE) + bin_edges = np.linspace(min_db, max_db, N_BINS) + + # Histogram of rain values + hist, _ = np.histogram(data[rain_mask], bins=bin_edges, density=True) + + # Compute cumulative distribution + cumulative_distr = np.cumsum(hist) / np.sum(hist) + + # Create an xarray Dataset to store both cumulative distribution and bin edges + cdf_dataset = xr.Dataset( + { + "cdf": (["bin"], cumulative_distr), + }, + coords={ + # bin_edges[:-1] to match the histogram bins + "cdf_bins": (["bin"], bin_edges[:-1]), + }, + attrs={ + "description": "Cumulative probability distribution of rain rates", + "units": "dB", + } + ) + + return cdf_dataset + + +def compute_field_stats(data, geodata): + nonzero_mask = data > geodata["zerovalue"] + nonzero_mean = np.mean(data[nonzero_mask]) if np.any( + nonzero_mask) else np.nan + nonzero_stdev = np.std(data[nonzero_mask]) if np.any( + nonzero_mask) else np.nan + nonzero_frac = np.sum(nonzero_mask) / data.size + mean_rain = np.nanmean(data) + stdev_rain = np.nanstd(data) + + rain_stats = { + "nonzero_mean": float(nonzero_mean) if nonzero_mean is not None else None, + "nonzero_stdev": float(nonzero_stdev) if nonzero_stdev is not None else None, + "nonzero_fraction": float(nonzero_frac) if nonzero_frac is not None else None, + "mean": float(mean_rain) if mean_rain is not None else None, + "stdev": float(stdev_rain) if stdev_rain is not None else None, + "transform": geodata["transform"], + "zerovalue": geodata["zerovalue"], + "threshold": geodata["threshold"] + } + return rain_stats + + +def get_param_by_key( + params_df: pd.DataFrame, + valid_time: datetime.datetime, + base_time: Optional[datetime.datetime] = None, + ensemble: Optional[Union[int, str]] = None, + strict: bool = False +) -> Optional[StochasticRainParameters]: + """ + Retrieve the StochasticRainParameters object from a DataFrame index. + + Uses 'NA' as sentinel for missing base_time/ensemble. + + Args: + params_df (pd.DataFrame): Indexed by (valid_time, base_time, ensemble). + valid_time (datetime): Required valid_time. + base_time (datetime or None): Optional base_time. + ensemble (int, str, or None): Optional ensemble. 
+ strict (bool): Raise KeyError if not found (default: False = return None) + + Returns: + StochasticRainParameters or None + """ + base_time = base_time if base_time is not None else "NA" + ensemble = ensemble if ensemble is not None else "NA" + try: + return params_df.loc[(valid_time, base_time, ensemble), "param"] + except KeyError: + if strict: + raise + return None + + +def is_stationary(phi1, phi2): + return abs(phi2) < 1 and (phi1 + phi2) < 1 and (phi2 - phi1) < 1 + + +def correlation_length(lag1: float, lag2: float, dx=10, tol=1e-4, max_lag=1000): + """ + Calculate the correlation length in minutes assuming AR(2) process + Args: + lag1 (float): Lag 1 auto-correltion + lag2 (float): Lag 2 auto-correlation + dx (int, optional): time step between lag1 & 2 in minutes. Defaults to 10. + tol (float, optional): _description_. Defaults to 1e-4. + max_lag (int, optional): _description_. Defaults to 1000. + + Returns: + corl (float): Correlation length in minutes + np.nan on error + """ + if lag1 is None or lag2 is None: + return np.nan + + A = np.array([[1.0, lag1], [lag1, 1.0]]) + b = np.array([lag1, lag2]) + + try: + phi = np.linalg.solve(A, b) + except np.linalg.LinAlgError: + return np.nan + + phi1, phi2 = phi + if not is_stationary(phi1, phi2): + return np.nan + + rho_vals = [1.0, lag1, lag2] + for _ in range(3, max_lag): + next_rho = phi1 * rho_vals[-1] + phi2 * rho_vals[-2] + if abs(next_rho) < tol: + break + rho_vals.append(next_rho) + corl = np.trapz(rho_vals, dx=dx) + return corl + + +def power_law_acor(config: Dict[str, Any], T_ref: float) -> np.ndarray: + """ + Compute lag-1 and lag-2 autocorrelations for each cascade level using a power-law model. + + Args: + config (dict): Configuration dictionary with 'pysteps.timestep' (in seconds) + and 'dynamic_scaling' parameters. + T_ref (float): Reference correlation length T(t, L) at the largest scale (in minutes). + + Returns: + np.ndarray: Array of shape (n_levels, 2) with [lag1, lag2] for each level. 
+ np.ndarray: Array of corelation lengths per level + """ + dt_seconds = config["pysteps"]["timestep"] + dt_mins = dt_seconds / 60.0 + + ds_config = config.get("dynamic_scaling", {}) + scales = ds_config["central_wave_lengths"] + ht = ds_config["space_time_exponent"] + a = ds_config["lag2_constants"] + b = ds_config["lag2_exponents"] + + L = scales[0] + T_levels = [T_ref * (l / L) ** ht for l in scales] + + lags = np.empty((len(scales), 2), dtype=np.float32) + for ia, T_l in enumerate(T_levels): + pl_lag1 = np.exp(-dt_mins / T_l) + pl_lag2 = a[ia] * (pl_lag1 ** b[ia]) + lags[ia, 0] = pl_lag1 + lags[ia, 1] = pl_lag2 + + return lags, T_levels + +def blend_param(qpe_params, nwp_params, param_names, weight): + for pname in param_names: + + qval = getattr(qpe_params, pname, None) + nval = getattr(nwp_params, pname, None) + if isinstance(qval, (int, float)) and isinstance(nval, (int, float)): + setattr(nwp_params, pname, weight * qval + (1 - weight) * nval) + elif isinstance(qval, list) and isinstance(nval, list) and len(qval) == len(nval): + setattr(nwp_params, pname, [ + weight * q + (1 - weight) * n for q, n in zip(qval, nval)]) + return nwp_params + + +def blend_parameters(config, blend_base_time: datetime.datetime, nwp_param_df: pd.DataFrame, rad_param: StochasticRainParameters, + weight_fn: Callable[[float], float] = None + ) -> pd.DataFrame: + + if weight_fn is None: + def weight_fn(lag_sec): return np.exp(-(lag_sec / 10800) + ** 2) # 3h Gaussian + blended_param_names = [ + "nonzero_mean_db", + "nonzero_stdev_db", + "rain_fraction", + "beta_1", + "beta_2", + "corl_zero" + ] + blended_df = copy.deepcopy(nwp_param_df) + for vtime in blended_df.index: + lag_sec = (vtime - blend_base_time).total_seconds() + weight = weight_fn(lag_sec) + + # Select the parameter object for this vtime and blend + original = blended_df.loc[vtime, "param"] + clean_original = copy.deepcopy(original) + updated = blend_param(rad_param, clean_original, blended_param_names, weight) + + # Update the auto-correlations using the dynamic scaling parameters + updated.calc_acor(config) + blended_df.loc[vtime, "param"] = updated + + return blended_df + diff --git a/pysteps/param/stochastic_generator.py b/pysteps/param/stochastic_generator.py new file mode 100644 index 000000000..15bd178d4 --- /dev/null +++ b/pysteps/param/stochastic_generator.py @@ -0,0 +1,195 @@ +# Contains: gen_stoch_field, normalize_db_field, pl_filter +from typing import Optional +import numpy as np +from scipy import interpolate, stats +from models import StochasticRainParameters + +def gen_stoch_field(steps_params: StochasticRainParameters, nx: int, ny: int): + """ + Generate a rain field with normal distribution and a power law power spectrum + Args: + steps_params (StochasticRainParameters): The dataclass with all the steps parameters + nx (int): x dimension of the output field + ny (int): y dimension of the output field + + Returns: + np.ndarray: Output field with shape (ny,nx) + """ + + beta_1 = steps_params.beta_1 + beta_2 = steps_params.beta_2 + pixel_size = steps_params.kmperpixel + scale_break = pixel_size * steps_params.scale_break + + # generate uniform random numbers in the range 0,1 + y = np.random.uniform(low=0, high=1, size=(ny, nx)) + + # Power law filter the field + fft = np.fft.fft2(y, (ny, nx)) + filter = pl_filter(beta_1, nx, ny, pixel_size, beta_2, scale_break) + out_fft = fft * filter + out_field = np.fft.ifft2(out_fft).real + + nbins = 250 + res = stats.cumfreq(out_field, numbins=nbins) + bins = [res.lowerlimit + ia * + 
res.binsize for ia in range(1+res.cumcount.size)] + count = res.cumcount / res.cumcount[nbins-1] + + # find the threshold value for this non-rain probability + rain_bin = 0 + for ia in range(nbins): + if count[ia] <= 1 - steps_params.rain_fraction: + rain_bin = ia + else: + break + rain_threshold = bins[rain_bin] + + # Shift the data to have the correct probability > 0 + norm_data = out_field - rain_threshold + + # Now we need to transform the "raining" samples to have the desired distribution + rain_mask = norm_data > steps_params.threshold + rain_obs = norm_data[rain_mask] + rain_res = stats.cumfreq(rain_obs, numbins=nbins) + rain_bins = [rain_res.lowerlimit + ia * + rain_res.binsize for ia in range(1+rain_res.cumcount.size)] + rain_cdf = rain_res.cumcount / rain_res.cumcount[nbins-1] + + # rain_bins are the bin edges; use bin centers for interpolation + bin_centers = 0.5 * (np.array(rain_bins[:-1]) + np.array(rain_bins[1:])) + + # Step 1: Build LUT: map empirical CDF → target normal quantiles + # Make sure rain_cdf values are in (0,1) to avoid issues with extreme tails + eps = 1e-6 + rain_cdf_clipped = np.clip(rain_cdf, eps, 1 - eps) + + # Map rain_cdf quantiles to corresponding values in the target normal distribution + target_mu = steps_params.nonzero_mean_db + target_sigma = steps_params.nonzero_stdev_db * 0.80 + normal_values = stats.norm.ppf( + rain_cdf_clipped, loc=target_mu, scale=target_sigma) + + # Create interpolation function from observed rain values to target normal values + cdf_transform = interpolate.interp1d( + bin_centers, normal_values, + kind="linear", bounds_error=False, + fill_value=(normal_values[0], normal_values[-1]) + ) + + # Transform raining pixels + norm_data[rain_mask] = cdf_transform(norm_data[rain_mask]) + return norm_data + + +def normalize_db_field(data, params): + if params.rain_fraction < 0.025: + return np.full_like(data, params.zerovalue) + + nbins = 250 + res = stats.cumfreq(data, numbins=nbins) + bins = [res.lowerlimit + ia * + res.binsize for ia in range(1+res.cumcount.size)] + count = res.cumcount / res.cumcount[nbins-1] + + # find the threshold value for this non-rain probability + rain_bin = 0 + for ia in range(nbins): + if count[ia] <= 1 - params.rain_fraction: + rain_bin = ia + else: + break + rain_threshold = bins[rain_bin+1] + + # Shift the data to have the correct probability of rain + norm_data = data + (params.threshold - rain_threshold) + + # Now we need to transform the raining samples to have the desired distribution + # Get the sample distribution + rain_mask = norm_data > params.threshold + rain_obs = norm_data[rain_mask] + rain_res = stats.cumfreq(rain_obs, numbins=nbins) + rain_bins = [rain_res.lowerlimit + ia * + rain_res.binsize for ia in range(1+rain_res.cumcount.size)] + rain_cdf = rain_res.cumcount / rain_res.cumcount[nbins-1] + + # rain_bins are the bin edges; use bin centers for interpolation + bin_centers = 0.5 * (np.array(rain_bins[:-1]) + np.array(rain_bins[1:])) + + # Step 1: Build LUT: map empirical CDF → target normal quantiles + # Make sure rain_cdf values are in (0,1) to avoid issues with extreme tails + eps = 5e-3 + rain_cdf_clipped = np.clip(rain_cdf, eps, 1 - eps) + + # Map rain_cdf quantiles to corresponding values in the target normal distribution + # We need to reduce the bias in the output fields + bias_adj = 0.85 + target_mu = params.nonzero_mean_db + target_sigma = params.nonzero_stdev_db * bias_adj + normal_values = stats.norm.ppf( + rain_cdf_clipped, loc=target_mu, scale=target_sigma) + + # Create 
interpolation function from observed rain values to target normal values + cdf_transform = interpolate.interp1d( + bin_centers, normal_values, + kind="linear", bounds_error=False, + fill_value=(normal_values[0], normal_values[-1]) + ) + + # Transform raining pixels + norm_data[rain_mask] = cdf_transform(norm_data[rain_mask]) + return norm_data + +def pl_filter(beta_1: float, nx: int, ny: int, pixel_size: float, beta_2: Optional[float] = None, scale_break: Optional[float] = None, + ): + """ + Generate a 2D low-pass power-law filter for FFT filtering. + + Parameters: + beta_1 (float): Power law exponent for frequencies < f1 (low frequencies) + nx (int): Number of columns (width) in the 2D field + ny (int): Number of rows (height) in the 2D field + pixel_size (float): Pixel size in km + beta_2 (float): Power law exponent for frequencies > f1 (high frequencies) Optional + scale_break (float): Break scale in km Optional + + Returns: + np.ndarray: 2D FFT low-pass filter + """ + + # Compute the frequency grid + freq_x = np.fft.fftfreq(nx, d=pixel_size) # Frequency in x-direction + freq_y = np.fft.fftfreq(ny, d=pixel_size) # Frequency in y-direction + + # 2D array with radial frequency + freq_r = np.sqrt(freq_x[:, None] ** 2 + freq_y[None, :] ** 2) + + # Initialize the radial 2D filter + filter_r = np.ones_like(freq_r) # Initialize with ones + f_zero = freq_x[1] + + if beta_2 is not None: + b1 = beta_1 / 2.0 + b2 = (beta_2-0.3) / 2.0 + + f1 = 1 / scale_break # Convert scale break to frequency domain + weight = (f1/f_zero) ** b1 + + # Apply the power-law function for a **low-pass filter** + # Handle division by zero at freq = 0 + with np.errstate(divide='ignore', invalid='ignore'): + mask_low = freq_r < f1 # Frequencies lower than the break + mask_high = ~mask_low # Frequencies higher than or equal to the break + + filter_r[mask_low] = (freq_r[mask_low]/f_zero) ** b1 + filter_r[mask_high] = weight * (freq_r[mask_high] / f1) ** b2 + + # Ensure DC component (zero frequency) is handled properly + filter_r[freq_r == 0] = 1 # Preserve the mean component + else: + b1 = beta_1 / 2.0 + mask = freq_r > 0 + filter_r[mask] = (freq_r[mask]/f_zero) ** b1 + filter_r[freq_r == 0] = 1 # Preserve the mean component + + return filter_r From ea7f7b4cebd211cb8cc83f7d39083682face5e42 Mon Sep 17 00:00:00 2001 From: alanseed Date: Tue, 15 Jul 2025 09:05:40 +1000 Subject: [PATCH 02/12] added transformer class --- pysteps/utils/__init__.py | 1 + pysteps/utils/transformer.py | 172 +++++++++++++++++++++++++++++++++++ 2 files changed, 173 insertions(+) create mode 100644 pysteps/utils/transformer.py diff --git a/pysteps/utils/__init__.py b/pysteps/utils/__init__.py index 9594a75ae..4348d4db4 100644 --- a/pysteps/utils/__init__.py +++ b/pysteps/utils/__init__.py @@ -12,3 +12,4 @@ from .tapering import * from .transformation import * from .reprojection import * +from .transformer import * diff --git a/pysteps/utils/transformer.py b/pysteps/utils/transformer.py new file mode 100644 index 000000000..bcb233e26 --- /dev/null +++ b/pysteps/utils/transformer.py @@ -0,0 +1,172 @@ +import numpy as np +import scipy.stats as scipy_stats +from scipy.interpolate import interp1d +from typing import Optional + +class BaseTransformer: + def __init__(self, threshold: float = 0.1, zerovalue: Optional[float] = None): + self.threshold = threshold + self.zerovalue = zerovalue + self.metadata = {} + + def transform(self, R: np.ndarray) -> np.ndarray: + raise NotImplementedError + + def inverse_transform(self, R: np.ndarray) -> np.ndarray: + raise 
NotImplementedError + + def get_metadata(self) -> dict: + return self.metadata.copy() + +class DBTransformer(BaseTransformer): + def transform(self, R: np.ndarray) -> np.ndarray: + R = R.copy() + mask = R < self.threshold + R[~mask] = 10.0 * np.log10(R[~mask]) + threshold_db = 10.0 * np.log10(self.threshold) + + if self.zerovalue is None: + self.zerovalue = threshold_db - 5 + + R[mask] = self.zerovalue + + self.metadata = { + "transform": "dB", + "threshold": threshold_db, + "zerovalue": self.zerovalue, + } + return R + + def inverse_transform(self, R: np.ndarray) -> np.ndarray: + R = R.copy() + threshold_lin = 10.0 ** (self.metadata["threshold"] / 10.0) + R = 10.0 ** (R / 10.0) + R[R < threshold_lin] = self.metadata["zerovalue"] + self.metadata["transform"] = None + return R + + +class BoxCoxTransformer(BaseTransformer): + def __init__(self, Lambda: float = 0.0, **kwargs): + super().__init__(**kwargs) + self.Lambda = Lambda + + def transform(self, R: np.ndarray) -> np.ndarray: + R = R.copy() + mask = R < self.threshold + + if self.Lambda == 0.0: + R[~mask] = np.log(R[~mask]) + tval = np.log(self.threshold) + else: + R[~mask] = (R[~mask] ** self.Lambda - 1) / self.Lambda + tval = (self.threshold ** self.Lambda - 1) / self.Lambda + + if self.zerovalue is None: + self.zerovalue = tval - 1 + + R[mask] = self.zerovalue + + self.metadata = { + "transform": "BoxCox", + "lambda": self.Lambda, + "threshold": tval, + "zerovalue": self.zerovalue, + } + return R + + def inverse_transform(self, R: np.ndarray) -> np.ndarray: + R = R.copy() + if self.Lambda == 0.0: + R = np.exp(R) + else: + R = np.exp(np.log(self.Lambda * R + 1) / self.Lambda) + + threshold_inv = ( + np.exp(np.log(self.Lambda * self.metadata["threshold"] + 1) / self.Lambda) + if self.Lambda != 0.0 else + np.exp(self.metadata["threshold"]) + ) + + R[R < threshold_inv] = self.metadata["zerovalue"] + self.metadata["transform"] = None + return R + +class NQTransformer(BaseTransformer): + def __init__(self, a: float = 0.0, **kwargs): + super().__init__(**kwargs) + self.a = a + self._inverse_interp = None + + def transform(self, R: np.ndarray) -> np.ndarray: + R = R.copy() + shape = R.shape + R = R.ravel() + mask = ~np.isnan(R) + R_ = R[mask] + + n = R_.size + Rpp = ((np.arange(n) + 1 - self.a) / (n + 1 - 2 * self.a)) + Rqn = scipy_stats.norm.ppf(Rpp) + R_sorted = R_[np.argsort(R_)] + R_trans = np.interp(R_, R_sorted, Rqn) + + self.zerovalue = np.min(R_) + R_trans[R_ == self.zerovalue] = 0 + + self._inverse_interp = interp1d( + Rqn, R_sorted, bounds_error=False, + fill_value=(float(R_sorted.min()), float(R_sorted.max())) # type: ignore + ) + + R[mask] = R_trans + R = R.reshape(shape) + + self.metadata = { + "transform": "NQT", + "threshold": R_trans[R_trans > 0].min(), + "zerovalue": 0, + } + return R + + def inverse_transform(self, R: np.ndarray) -> np.ndarray: + if self._inverse_interp is None: + raise RuntimeError("Must call transform() before inverse_transform()") + + R = R.copy() + shape = R.shape + R = R.ravel() + mask = ~np.isnan(R) + R[mask] = self._inverse_interp(R[mask]) + R = R.reshape(shape) + + self.metadata["transform"] = None + return R + +class SqrtTransformer(BaseTransformer): + def transform(self, R: np.ndarray) -> np.ndarray: + R = np.sqrt(R) + self.metadata = { + "transform": "sqrt", + "threshold": np.sqrt(self.threshold), + "zerovalue": np.sqrt(self.zerovalue) if self.zerovalue else 0.0 + } + return R + + def inverse_transform(self, R: np.ndarray) -> np.ndarray: + R = R**2 + self.metadata["transform"] = None + return R + 
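+# A minimal usage sketch (illustrative only; the array values are made up).
+# DBTransformer, defined above, converts rain rates at or above `threshold`
+# to dB and sets the remaining pixels to `zerovalue`, recording the applied
+# transform in the metadata returned by get_metadata():
+#
+#     import numpy as np
+#     from pysteps.utils.transformer import DBTransformer
+#
+#     rain = np.array([[0.0, 0.2, 5.0], [12.5, 0.05, 1.0]])  # mm/h
+#     db_tr = DBTransformer(threshold=0.1)
+#     rain_db = db_tr.transform(rain)      # dB field
+#     meta = db_tr.get_metadata()          # {"transform": "dB", ...}
+#     rain_mmh = db_tr.inverse_transform(rain_db)
+#
+# The factory below selects a transformer by name, e.g.
+# get_transformer("boxcox", Lambda=0.5, threshold=0.1).
+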
+def get_transformer(name: str, **kwargs) -> BaseTransformer: + name = name.lower() + if name == "boxcox": + return BoxCoxTransformer(**kwargs) + elif name == "db": + return DBTransformer(**kwargs) + elif name == "nqt": + return NQTransformer(**kwargs) + elif name == "sqrt": + return SqrtTransformer(**kwargs) + else: + raise ValueError(f"Unknown transformer type: {name}") From cf2aaf3120d3cb40798c39cf9a0a15105d524ab2 Mon Sep 17 00:00:00 2001 From: alanseed Date: Wed, 6 Aug 2025 10:08:09 +1000 Subject: [PATCH 03/12] improved metadata management --- pysteps/utils/transformer.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/pysteps/utils/transformer.py b/pysteps/utils/transformer.py index bcb233e26..42f7475ef 100644 --- a/pysteps/utils/transformer.py +++ b/pysteps/utils/transformer.py @@ -4,7 +4,7 @@ from typing import Optional class BaseTransformer: - def __init__(self, threshold: float = 0.1, zerovalue: Optional[float] = None): + def __init__(self, threshold: float = 0.5, zerovalue: Optional[float] = None): self.threshold = threshold self.zerovalue = zerovalue self.metadata = {} @@ -19,30 +19,38 @@ def get_metadata(self) -> dict: return self.metadata.copy() class DBTransformer(BaseTransformer): - def transform(self, R: np.ndarray) -> np.ndarray: - R = R.copy() - mask = R < self.threshold - R[~mask] = 10.0 * np.log10(R[~mask]) + """ + DBTransformer applies a thresholded dB transform to rain rate fields. + + Parameters: + threshold (float): Rain rate threshold (in mm/h). Values below this are set to `zerovalue` in dB. + zerovalue (Optional[float]): Value in dB space to assign below-threshold pixels. If None, defaults to log10(threshold) - 0.1 + """ + + def __init__(self, threshold: float = 0.5, zerovalue: Optional[float] = None): + super().__init__(threshold, zerovalue) threshold_db = 10.0 * np.log10(self.threshold) if self.zerovalue is None: - self.zerovalue = threshold_db - 5 - - R[mask] = self.zerovalue + self.zerovalue = threshold_db - 0.1 self.metadata = { "transform": "dB", - "threshold": threshold_db, - "zerovalue": self.zerovalue, + "threshold": self.threshold, # stored in mm/h + "zerovalue": self.zerovalue # stored in dB } + + def transform(self, R: np.ndarray) -> np.ndarray: + R = R.copy() + mask = R < self.threshold + R[~mask] = 10.0 * np.log10(R[~mask]) + R[mask] = self.zerovalue return R def inverse_transform(self, R: np.ndarray) -> np.ndarray: R = R.copy() - threshold_lin = 10.0 ** (self.metadata["threshold"] / 10.0) R = 10.0 ** (R / 10.0) - R[R < threshold_lin] = self.metadata["zerovalue"] - self.metadata["transform"] = None + R[R < self.threshold] = 0 return R From fef0c06b232aeb1b9180153ee4830c0f1d1ccb1c Mon Sep 17 00:00:00 2001 From: Alan Seed Date: Fri, 29 Aug 2025 11:42:29 +1000 Subject: [PATCH 04/12] update param modules --- pysteps/param/README.md | 45 -- pysteps/param/__init__.py | 23 + pysteps/param/calibrate_ar_model.py | 321 ------------- pysteps/param/cascade_utils.py | 81 +++- pysteps/param/make_cascades.py | 281 ----------- pysteps/param/make_parameters.py | 413 ----------------- pysteps/param/nc_utils.py | 107 +++++ pysteps/param/nwp_param_qc.py | 199 -------- pysteps/param/pysteps_param.py | 369 --------------- pysteps/param/rainfield_stats.py | 509 ++++++++++++++++++++ pysteps/param/shared_utils.py | 638 ++++++++++++++++++++----- pysteps/param/steps_params.py | 643 +++++--------------------- pysteps/param/stochastic_generator.py | 148 +++--- pysteps/param/transformer.py | 187 ++++++++ 14 files changed, 
1607 insertions(+), 2357 deletions(-) delete mode 100644 pysteps/param/README.md create mode 100644 pysteps/param/__init__.py delete mode 100644 pysteps/param/calibrate_ar_model.py delete mode 100644 pysteps/param/make_cascades.py delete mode 100644 pysteps/param/make_parameters.py create mode 100644 pysteps/param/nc_utils.py delete mode 100644 pysteps/param/nwp_param_qc.py delete mode 100644 pysteps/param/pysteps_param.py create mode 100644 pysteps/param/rainfield_stats.py create mode 100644 pysteps/param/transformer.py diff --git a/pysteps/param/README.md b/pysteps/param/README.md deleted file mode 100644 index fb5a56c84..000000000 --- a/pysteps/param/README.md +++ /dev/null @@ -1,45 +0,0 @@ -# param - -## executable scripts - -### pysteps_param.py - -This is the main script to generate an ensemble nowcast using the parametric algoithms - -### make_cascades.py - -Script to decompose and track the rainfall fields. The cascade states are written back into a GridFS bucket for later processing. - -### make_parameters.py - -Script to read the rainfall and cascade state data and calculate the STEPS parameters. The parameters are written back into a Mongo collection. - -### nwp_param_qc.py - -The NWP rainfall fields are derived from interpolating hourly ensembles onto a 10-min, 2 km resolution and contain significant errors as a result. This script cleans and smoothes the parameters that have been derived from the NWP ensemble and makes them ready for use. - -### calibrate_ar_model.py - -This scripts reads the radar rain fields and calibrates the dynamic scaling model. The output is a set of figures for quality assurance and a JSON file with the model parameters that can be included in the main configuration JSON file. - -## Modules - -### broken_line.py - -Implementation of the broken line model, not used at this stage but could be used to generate time series of STEPS parameters in the future. 
- -### cascade_utils.py - -A simple function to calculate the scale for each cascade level - -### shared_utils.py - -Functions that are likely to be used by the various forms of pySTEPS_param - -### steps_params.py - -Data class to manage the parameters and functions to operate on them - -### stochastic_generator.py - -Function to generate a single stochastic field given the parameters diff --git a/pysteps/param/__init__.py b/pysteps/param/__init__.py new file mode 100644 index 000000000..a48e10992 --- /dev/null +++ b/pysteps/param/__init__.py @@ -0,0 +1,23 @@ +from .steps_params import StepsParameters +from .rainfield_stats import ( + RainfieldStats, + compute_field_parameters, + compute_field_stats, + power_spectrum_1D, + correlation_length, + power_law_acor, +) +from .stochastic_generator import gen_stoch_field, normalize_db_field, pl_filter +from .cascade_utils import calculate_wavelengths, lagr_auto_cor +from .shared_utils import ( + qc_params, + update_field, + blend_parameters, + zero_state, + is_zero_state, + calc_corls, + fit_auto_cors, + calculate_parameters, +) +from .nc_utils import generate_geo_dict, generate_geo_dict_xy, read_qpe_netcdf +from .transformer import DBTransformer diff --git a/pysteps/param/calibrate_ar_model.py b/pysteps/param/calibrate_ar_model.py deleted file mode 100644 index 6c0de1dca..000000000 --- a/pysteps/param/calibrate_ar_model.py +++ /dev/null @@ -1,321 +0,0 @@ -""" -Estimate the parameters that manage the AR model using the observed QPE data -Write out the dynamic scaling configuration to a JSON file -""" - -from models import get_db, get_config -import datetime -import numpy as np -import pymongo -import argparse -import logging -import pandas as pd -from scipy.optimize import curve_fit -from pathlib import Path -import statsmodels.api as sm -from pysteps.cascade.bandpass_filters import filter_gaussian -import json -import matplotlib.pyplot as plt - - -def is_valid_iso8601(time_str: str) -> bool: - try: - datetime.datetime.fromisoformat(time_str) - return True - except ValueError: - return False - - -def get_auto_corls(db: pymongo.MongoClient, product: str, start_time: datetime.datetime, end_time: datetime.datetime): - params_col = db["AKL.params"] - query = { - 'metadata.valid_time': {'$gte': start_time, '$lte': end_time}, - 'metadata.product': product - } - projection = {"_id": 0, "metadata": 1, "cascade": 1} - data_cursor = params_col.find(query, projection=projection) - data_list = list(data_cursor) - - logging.info(f'Found {len(data_list)} documents') - - rows = [] - for doc in data_list: - row = {"valid_time": doc["metadata"]["valid_time"]} - - lag1 = doc.get("cascade", {}).get("lag1") - lag2 = doc.get("cascade", {}).get("lag2") - stds = doc.get("cascade", {}).get("stds") - - if lag1 is None or lag2 is None or stds is None: - continue - - for ia, val in enumerate(lag1): - row[f"lag1_{ia}"] = val - for ia, val in enumerate(lag2): - row[f"lag2_{ia}"] = val - for ia, val in enumerate(stds): - row[f"stds_{ia}"] = val - - rows.append(row) - - return pd.DataFrame(rows) - - -def power_law(x, a, b): - return a * np.power(x, b) - - -def fit_power_law(qpe_df, lev): - lag1_vals = qpe_df[f"lag1_{lev}"].values - lag2_vals = qpe_df[f"lag2_{lev}"].values - - q05 = np.quantile(lag1_vals, 0.05) - mask = lag1_vals > q05 - x = lag1_vals[mask] - y = lag2_vals[mask] - - coefs, _ = curve_fit(power_law, x, y) - - y_pred = power_law(x, coefs[0], coefs[1]) - ss_res = np.sum((y - y_pred) ** 2) - ss_tot = np.sum((y - np.mean(y)) ** 2) - r_squared = 1 - ss_res / ss_tot - 
- return (coefs[0], coefs[1], r_squared) - -def is_stationary(phi1, phi2): - return abs(phi2) < 1 and (phi1 + phi2) < 1 and (phi2 - phi1) < 1 - -def correlation_length(lag1, lag2, tol=1e-4, max_lag=1000): - if lag1 is None or lag2 is None: - return np.nan - - A = np.array([[1.0, lag1], [lag1, 1.0]]) - b = np.array([lag1, lag2]) - - try: - phi = np.linalg.solve(A, b) - except np.linalg.LinAlgError: - return np.nan - - phi1, phi2 = phi - if not is_stationary(phi1, phi2): - return np.nan - - rho_vals = [1.0, lag1, lag2] - for _ in range(3, max_lag): - next_rho = phi1 * rho_vals[-1] + phi2 * rho_vals[-2] - if abs(next_rho) < tol: - break - rho_vals.append(next_rho) - - return np.trapz(rho_vals, dx=10) - -import os -def generate_qaqc_plots(cor_len_df, ht, scales, lag2_constants, lag2_exponents, n_levels, output_prefix=""): - # Set up color map - cmap = plt.colormaps.get_cmap('tab10') - - # Create output directory if it doesn't exist - figs_dir = os.path.join("..", "figs") - os.makedirs(figs_dir, exist_ok=True) - - # Plot fitted vs observed lag1 & lag2 for three percentiles of cl_0 - percentiles = [95, 50, 5] - cl0_values = cor_len_df["cl_0"].values - pvals = np.percentile(cl0_values, percentiles) - L = scales[0] # reference scale in km - - for pval, pstr in zip(pvals, percentiles): - # Find closest row - idx = (np.abs(cl0_values - pval)).argmin() - row = cor_len_df.iloc[idx] - time_str = row["valid_time"].strftime("%Y-%m-%d %H:%M") - - # Set up the scaling correlation lengths for this case - T_ref = row["cl_0"] # T(t, L) at largest scale - T_levels = [T_ref * (l / L) ** ht for l in scales] - dt = 10 - - obs_lag1 = [] - obs_lag2 = [] - fit_lag1 = [] - fit_lag2 = [] - levels = [] - for ilevel in range(n_levels): - lag1 = row[f"lag1_{ilevel}"] - lag2 = row[f"lag2_{ilevel}"] - - a = lag2_constants[ilevel] - b = lag2_exponents[ilevel] - pl_lag1 = np.exp(-dt / T_levels[ilevel]) - pl_lag2 = a * (pl_lag1 ** b) - obs_lag1.append(lag1) - obs_lag2.append(lag2) - fit_lag1.append(pl_lag1) - fit_lag2.append(pl_lag2) - levels.append(ilevel) - - plt.figure(figsize=(6, 4)) - color_lag1 = cmap(1) - color_lag2 = cmap(2) - - plt.plot(scales, obs_lag1, 'x-', label='Observed lag1', color=color_lag1) - plt.plot(scales, fit_lag1, 'o-', label='Fit lag1',color= color_lag1) - plt.plot(scales, obs_lag2, 'x--', label='Observed lag2', color=color_lag2) - plt.plot(scales, fit_lag2, 'o--', label='Fit lag2', color=color_lag2) - - plt.xscale("log") - plt.xlabel("Scale (km)") - plt.ylabel("Autocorrelation") - plt.title(f"Fit vs Obs @ cl_0 ~ {pstr}th percentile \n{time_str}, corl len = {T_ref:.0f} min") - plt.grid(True, which="both", ls="--", alpha=0.6) - plt.legend() - plt.tight_layout() - - filename = f"{output_prefix}lags_{pstr}th_percentile.png" - plt.savefig(os.path.join(figs_dir, filename)) - plt.close() - - -def main(): - parser = argparse.ArgumentParser( - description="Calculate the parameters for the dynamic scaling model") - parser.add_argument('-s', '--start', type=str, - required=True, help='Start time yyyy-mm-ddTHH:MM:SS') - parser.add_argument('-e', '--end', type=str, required=True, - help='End time yyyy-mm-ddTHH:MM:SS') - parser.add_argument('-n', '--name', type=str, - required=True, help='Name of domain [AKL]') - parser.add_argument('-c', '--config', type=Path, required=True, - help='Path to output dynamic scaling configuration file') - args = parser.parse_args() - - logging.basicConfig(level=logging.INFO) - - # Parse start and end time - try: - start_time = datetime.datetime.fromisoformat( - 
args.start).replace(tzinfo=datetime.timezone.utc) - end_time = datetime.datetime.fromisoformat( - args.end).replace(tzinfo=datetime.timezone.utc) - except ValueError: - logging.error("Invalid ISO 8601 date format for start or end time.") - return - - name = args.name - config_file_name = args.config - product = "QPE" - - db = get_db() - config = get_config(db, name) - n_rows = config["domain"]["n_rows"] - n_cols = config["domain"]["n_cols"] - n_levels = config["pysteps"]["n_cascade_levels"] - kmperpixel = config["pysteps"]["kmperpixel"] - - corl_df = get_auto_corls(db, product, start_time, end_time).dropna() - - if corl_df.empty: - logging.error( - "No valid correlation data found in the selected time range.") - return - - lag2_constants, lag2_exponents = [], [] - - for ilevel in range(n_levels): - a, b, rsq = fit_power_law(corl_df, ilevel) - if rsq < 0.5: - logging.info( - f"Warning: Rsq = {rsq:.2f}, using default power law for level {ilevel}") - a, b = 1.0, 2.4 - logging.info( - f"Level {ilevel}: lag2 = {a:.3f} * lag1^{b:.3f}, R² = {rsq:.2f}") - lag2_constants.append(a) - lag2_exponents.append(b) - - records = [] - for ilevel in range(n_levels): - lag1_col = f"lag1_{ilevel}" - lag2_col = f"lag2_{ilevel}" - - level_df = corl_df[["valid_time", lag1_col, lag2_col]].copy() - level_df["pl_lag2"] = lag2_constants[ilevel] * \ - np.power(level_df[lag1_col], lag2_exponents[ilevel]) - level_df[f"cl_{ilevel}"] = level_df.apply( - lambda row: correlation_length(row[lag1_col], row["pl_lag2"]), axis=1) - records.append( - level_df[["valid_time", f"cl_{ilevel}", lag1_col, lag2_col]]) - - cor_len_df = records[0] - for df in records[1:]: - cor_len_df = cor_len_df.merge(df, on="valid_time", how="outer") - - cor_len_df = cor_len_df.sort_values( - "valid_time").dropna().reset_index(drop=True) - - bp_filter = filter_gaussian((n_rows, n_cols), n_levels, kmperpixel) - scales = 1 / bp_filter["central_freqs"] - log_scales = np.log(scales) - - cl_columns = [f"cl_{i}" for i in range(n_levels)] - cl_data = cor_len_df[cl_columns].values - valid_mask = cl_data > 0 - log_cl_data = np.where(valid_mask, np.log(cl_data), np.nan) - - x_vals = np.tile(log_scales, (log_cl_data.shape[0], 1)).flatten() - y_vals = log_cl_data.flatten() - valid_idx = ~np.isnan(y_vals) - x_valid, y_valid = x_vals[valid_idx], y_vals[valid_idx] - - X = sm.add_constant(x_valid) - model = sm.OLS(y_valid, X).fit() - a, b = model.params - - print(model.summary()) - - # Median correlation length per scale (ignoring NaNs) - median_cl = np.nanmedian(log_cl_data, axis=0) - - # Scatter plot: log-scale vs median log(correlation length) - plt.figure(figsize=(8, 5)) - plt.scatter(log_scales, median_cl, label="Median log(correlation length)", color='blue') - - # Regression line - x_fit = np.linspace(min(log_scales), max(log_scales), 100) - y_fit = a + b * x_fit - plt.plot(x_fit, y_fit, color='red', label=f"OLS fit: y = {a:.2f} + {b:.2f}x") - - # Labels and formatting - plt.xlabel("log(Spatial scale [km])") - plt.ylabel("log(Correlation length [km])") - plt.title("Median correlation length vs scale (log-log)") - plt.grid(True) - plt.legend() - plt.tight_layout() - plt.show() - - percentiles = [95, 50, 5] - cl0_values = cor_len_df["cl_0"].values - pvals = np.percentile(cl0_values, percentiles) - - conf_dir = os.path.join("..", "run") - conf_path = os.path.join(conf_dir, config_file_name) - logging.info(f"Writing output dynamic scaling config to {conf_path} ") - with open(conf_path, "w") as f: - dynamic_scaling_config = {"dynamic_scaling": { - 
"central_wave_lengths": scales.tolist(), - "space_time_exponent": float(b), - "lag2_constants": lag2_constants, - "lag2_exponents": lag2_exponents, - "cor_len_percentiles": percentiles, - "cor_len_pvals": pvals.tolist() - }} - json.dump(dynamic_scaling_config, f, indent=2) - - generate_qaqc_plots(cor_len_df, b, scales, - lag2_constants, lag2_exponents, n_levels) - - -if __name__ == "__main__": - main() diff --git a/pysteps/param/cascade_utils.py b/pysteps/param/cascade_utils.py index 3ad4f3ad3..97e41104b 100644 --- a/pysteps/param/cascade_utils.py +++ b/pysteps/param/cascade_utils.py @@ -1,6 +1,8 @@ -import numpy as np +import numpy as np +from pysteps import extrapolation -def get_cascade_wavelengths(n_levels, domain_size_km, d=1.0, gauss_scale=0.5): + +def calculate_wavelengths(n_levels: int, domain_size: float, d: float = 1.0): """ Compute the central wavelengths (in km) for each cascade level. @@ -8,32 +10,79 @@ def get_cascade_wavelengths(n_levels, domain_size_km, d=1.0, gauss_scale=0.5): ---------- n_levels : int Number of cascade levels. - domain_size_km : int or float - The larger of the two spatial dimensions (in km) of the domain. + domain_size : int or float + The larger of the two spatial dimensions of the domain in pixels. d : float - Sample spacing (inverse of sampling rate). Default is 1. - gauss_scale : float - The Gaussian filter scaling parameter. + Sample frequency in pixels per km. Default is 1. Returns ------- wavelengths_km : np.ndarray Central wavelengths in km for each cascade level (length = n_levels). """ - # Compute q as in _gaussweights_1d - q = pow(0.5 * domain_size_km, 1.0 / n_levels) - + # Compute q + q = pow(0.5 * domain_size, 1.0 / n_levels) + # Compute central wavenumbers (in grid units) r = [(pow(q, k - 1), pow(q, k)) for k in range(1, n_levels + 1)] central_wavenumbers = np.array([0.5 * (r0 + r1) for r0, r1 in r]) - + # Convert to frequency - central_freqs = central_wavenumbers / domain_size_km - central_freqs[0] = 1.0 / domain_size_km # enforce first freq > 0 + central_freqs = central_wavenumbers / domain_size + central_freqs[0] = 1.0 / domain_size central_freqs[-1] = 0.5 # Nyquist limit - central_freqs *= d - # Convert to wavelength (in km) + # Convert wavelength to km, d is pixels per km + central_freqs = central_freqs * d central_wavelengths_km = 1.0 / central_freqs - return central_wavelengths_km + + +def lagr_auto_cor(data: np.ndarray, oflow: np.ndarray): + """ + Generate the Lagrangian auto correlations for STEPS cascades. + + Args: + data (np.ndarray): [T, L, M, N] where: + - T = ar_order + 1 (number of time steps) + - L = number of cascade levels + - M, N = spatial dimensions. + oflow (np.ndarray): [2, M, N] Optical flow vectors. + + Returns: + np.ndarray: Autocorrelation coefficients of shape (L, ar_order). + """ + ar_order = 2 + if data.shape[0] < (ar_order + 1): + raise ValueError( + f"Insufficient time steps. Expected at least {ar_order + 1}, got {data.shape[0]}." 
+ ) + + n_cascade_levels = data.shape[1] + extrapolation_method = extrapolation.get_method("semilagrangian") + + autocorrelation_coefficients = np.full((n_cascade_levels, ar_order), np.nan) + + for level in range(n_cascade_levels): + lag_1 = extrapolation_method(data[-2, level], oflow, 1)[0] + lag_1 = np.where(np.isfinite(lag_1), lag_1, 0) + + data_t = np.where(np.isfinite(data[-1, level]), data[-1, level], 0) + if np.std(lag_1) > 1e-1 and np.std(data_t) > 1e-1: + autocorrelation_coefficients[level, 0] = np.corrcoef( + lag_1.flatten(), data_t.flatten() + )[0, 1] + + if ar_order == 2: + lag_2 = extrapolation_method(data[-3, level], oflow, 1)[0] + lag_2 = np.where(np.isfinite(lag_2), lag_2, 0) + + lag_1 = extrapolation_method(lag_2, oflow, 1)[0] + lag_1 = np.where(np.isfinite(lag_1), lag_1, 0) + + if np.std(lag_1) > 1e-1 and np.std(data_t) > 1e-1: + autocorrelation_coefficients[level, 1] = np.corrcoef( + lag_1.flatten(), data_t.flatten() + )[0, 1] + + return autocorrelation_coefficients diff --git a/pysteps/param/make_cascades.py b/pysteps/param/make_cascades.py deleted file mode 100644 index d0b72ed2a..000000000 --- a/pysteps/param/make_cascades.py +++ /dev/null @@ -1,281 +0,0 @@ -""" -Script to decompose and track the input rainfall fields and load them into the database - -""" - -from models import store_cascade_to_gridfs, replace_extension, read_nc -from models import get_db, get_config -from pymongo import MongoClient -import logging -import argparse -import gridfs -import pymongo -import numpy as np -import datetime -import os -import sys - -from pysteps import motion -from pysteps.utils import transformation -from pysteps.cascade.decomposition import decomposition_fft -from pysteps.cascade.bandpass_filters import filter_gaussian - -from urllib.parse import quote_plus - -WAR_THRESHOLD = 0.05 # Select only fields with rain for analysis - -def is_valid_iso8601(time_str: str) -> bool: - """Check if the given string is a valid ISO 8601 datetime.""" - try: - datetime.datetime.fromisoformat(time_str) - return True - except ValueError: - return False - - -def process_files(file_names: list[str], db: MongoClient, config: dict): - timestep = config["pysteps"]["timestep"] - db_zerovalue = config["pysteps"]["zerovalue"] - n_levels = config['pysteps']['n_cascade_levels'] - n_rows = config['domain']['n_rows'] - n_cols = config['domain']['n_cols'] - name = config['name'] - - oflow_method = motion.get_method("LK") # Lucas-Kanade method - bp_filter = filter_gaussian((n_rows, n_cols), n_levels) - - time_delta_tolerance = 120 - min_delta_time = datetime.timedelta( - seconds=timestep - time_delta_tolerance) - max_delta_time = datetime.timedelta( - seconds=timestep + time_delta_tolerance) - - rain_col_name = f"{name}.rain" - state_col_name = f"{name}.state" - rain_fs = gridfs.GridFS(db, collection=rain_col_name) - state_fs = gridfs.GridFS(db, collection=state_col_name) - - # Initialize buffers for batch processing - prev_time = None - cur_time = None - prev_field = None - cur_field = None - file_names.sort() - - for file_name in file_names: - grid_out = rain_fs.find_one({"filename": file_name}) - if grid_out is None: - logging.warning(f"File {file_name} not found in GridFS, skipping.") - continue - - # Extract metadata safely - rain_fs_metadata = grid_out.metadata if hasattr(grid_out, "metadata") else {} - - if not rain_fs_metadata: - logging.warning(f"No metadata found for {file_name}, skipping.") - continue - - try: - - # Copy relevant metadata from rain_fs (MongoDB) to state_fs - field_metadata = { - 
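                # The keys copied below mirror the rain-file metadata so the
                # cascade record written to GridFS stays joinable with the
                # source rain field it was derived from.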
"filename": replace_extension(grid_out.filename, ".npz"), - "product": rain_fs_metadata.get("product", "unknown"), - "domain": rain_fs_metadata.get("domain", "AKL"), - "ensemble": rain_fs_metadata.get("ensemble", None), - "base_time": rain_fs_metadata.get("base_time", None), - "valid_time": rain_fs_metadata.get("valid_time", None), - "mean": rain_fs_metadata.get("mean", 0), - "std_dev": rain_fs_metadata.get("std_dev", 0), - "wetted_area_ratio": rain_fs_metadata.get("wetted_area_ratio", 0) - } - - # Check if cascade already exists for this file - filename = replace_extension(grid_out.filename, ".npz") - existing_file = state_fs.find_one({"filename": filename}) - if existing_file: - state_fs.delete(existing_file._id) - - # Read the input NetCDF file - in_buffer = grid_out.read() - rain_geodata, valid_time, rain_data = read_nc(in_buffer) - - # Transform the field to dB if needed - if rain_geodata.get("transform") is None: - db_data, db_geodata = transformation.dB_transform( - rain_data, rain_geodata, threshold=0.1, zerovalue=db_zerovalue - ) - db_data[~np.isfinite(db_data)] = db_geodata["zerovalue"] - else: - db_data = rain_data.copy() - db_geodata = rain_geodata.copy() - - # Perform cascade decomposition - cascade_dict = decomposition_fft( - db_data, bp_filter, compute_stats=True, normalize=True - ) - - # Add the rain field transformation for the cascade - cascade_dict["transform"] = "dB" - cascade_dict["zerovalue"] = db_zerovalue - cascade_dict["threshold"] = -10 # Assumes db_transform threshold = 0.1 - - # Compute optical flow - if prev_time is None: - prev_time = valid_time - cur_time = valid_time - prev_field = db_data - cur_field = db_data - else: - prev_time = cur_time - prev_field = cur_field - cur_time = valid_time - cur_field = db_data - - # Compute motion field if the time difference is in the acceptable range - V1 = np.zeros((2, n_rows, n_cols)) - tdiff = cur_time - prev_time - if min_delta_time < tdiff < max_delta_time: - R = np.array([prev_field, cur_field]) - V1 = oflow_method(R) - - - # Store cascade and motion field in GridFS with metadata - store_cascade_to_gridfs( - db, name, cascade_dict, V1, field_metadata["filename"], field_metadata) - - except Exception as e: - logging.error(f"Error processing {grid_out.filename}: {e}") - - -def main(): - parser = argparse.ArgumentParser( - description="Decompose and track rainfall fields") - - - parser.add_argument('-s', '--start', type=str, required=True, - help='Start time yyyy-mm-ddTHH:MM:SS') - parser.add_argument('-e', '--end', type=str, required=True, - help='End time yyyy-mm-ddTHH:MM:SS') - parser.add_argument('-n', '--name', type=str, required=True, - help='Name of domain [AKL]') - parser.add_argument('-p', '--product', type=str, required=True, - help='Name of input product [QPE, auckprec]') - - args = parser.parse_args() - - # Include app name (module name) in log output - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s [%(levelname)s] %(name)s: %(message)s', - stream=sys.stdout - ) - - logger = logging.getLogger(__name__) - logger.info("Starting cascade generation process") - - # Validate start and end time and read them in - if args.start and is_valid_iso8601(args.start): - start_time = datetime.datetime.fromisoformat(str(args.start)) - if start_time.tzinfo is None: - start_time = start_time.replace(tzinfo=datetime.timezone.utc) - else: - logging.error( - "Invalid start time format. 
Please provide a valid ISO 8601 time string.") - return - - if args.end and is_valid_iso8601(args.end): - end_time = datetime.datetime.fromisoformat(str(args.end)) - if end_time.tzinfo is None: - end_time = end_time.replace(tzinfo=datetime.timezone.utc) - else: - logging.error( - "Invalid start time format. Please provide a valid ISO 8601 time string.") - return - - name = str(args.name) - product = str(args.product) - if product not in ["QPE", "auckprec", "qpesim"]: - logging.error( - "Invalid product. Please provide either 'QPE' or 'auckprec'.") - return - - db = get_db() - config = get_config(db,name) - meta_coll = db[f"{name}.rain.files"] - - # Single pass through the data for qpe product - if product == "QPE": - f_filter = { - "metadata.product": product, - "metadata.valid_time": {"$gte": start_time, "$lte": end_time}, - "metadata.wetted_area_ratio": {"$gte": WAR_THRESHOLD} - } - - fields = {"_id": 0, "filename": 1, "metadata.wetted_area_ratio": 1} - results = meta_coll.find(filter=f_filter, projection=fields).sort( - "filename", pymongo.ASCENDING) - if results is None: - logging.error( - f"Failed to find {product}data for {start_time} - {end_time}") - return - - file_names = [doc["filename"] for doc in results] - logging.info( - f"Found {len(file_names)} {product} fields to process between {start_time} and {end_time}") - process_files(file_names, db, config) - else: - # Get the list of unique nwp run times in this period in ascending order - base_time_query = { - "metadata.product": product, - "metadata.base_time": {"$gte": start_time, "$lte": end_time} - } - base_times = meta_coll.distinct("metadata.base_time", base_time_query) - if base_times is None: - logging.error( - f"Failed to find {product} data for {start_time} - {end_time}") - return - - base_times.sort() - logging.info( - f"Found {len(base_times)} {product} NWP runs to process between {start_time} and {end_time}") - - for base_time in base_times: - logging.info(f"Processing NWP run {base_time}") - - # Get the list of unique ensembles found at base_time - ensembles = meta_coll.distinct( - "metadata.ensemble", {"metadata.product": product, "metadata.base_time": base_time}) - - if not ensembles: - logging.warning( - f"No ensembles found for base_time {base_time}") - continue # Skip this base_time if no ensembles exist - - logging.info( - f"Found {len(ensembles)} ensembles for base_time {base_time}") - ensembles.sort() - - for ensemble in ensembles: - # Get all the forecasts for this base_time and ensemble and process - - f_filter = { - "metadata.product": product, - "metadata.base_time": base_time, - "metadata.valid_time": {"$gte": start_time, "$lte": end_time}, - "metadata.ensemble": ensemble, - "metadata.wetted_area_ratio": {"$gte": WAR_THRESHOLD} - } - - fields = {"_id": 0, "filename": 1, - "metadata.wetted_area_ratio": 1} - results = meta_coll.find(filter=f_filter, projection=fields).sort( - "filename", pymongo.ASCENDING) - file_names = [doc["filename"] for doc in results] - - if len(file_names) > 0: - process_files(file_names, db, config) - - -if __name__ == "__main__": - main() diff --git a/pysteps/param/make_parameters.py b/pysteps/param/make_parameters.py deleted file mode 100644 index 34900c3b5..000000000 --- a/pysteps/param/make_parameters.py +++ /dev/null @@ -1,413 +0,0 @@ -""" -make_parameters.py -=================== - -Script to estimate the following STEPS parameters and place then in a MongoDB collection: -correl: lag 1 and 2 auto correlations for the cascade levels -b1, b2, l1: Slope of isotropic power 
spectrum above and below the scale l1 for rainfall field -mean, variance, wetted are ratio of the rainfall field -pdist: Sample cumulative probability distribution of rainfall field - -""" - -from models import read_nc, get_states, compute_field_parameters, get_db -from models import compute_field_stats, correlation_length - -import datetime -import numpy as np -import io -import pymongo -import gridfs -import argparse -import logging -from pymongo import MongoClient - -from pysteps.utils import transformation -from pysteps import extrapolation - -from urllib.parse import quote_plus -import os -import sys - -WAR_THRESHOLD = 0.05 # Select only fields with rain for analysis - - -def is_valid_iso8601(time_str: str) -> bool: - """Check if the given string is a valid ISO 8601 datetime.""" - try: - datetime.datetime.fromisoformat(time_str) - return True - except ValueError: - return False - - -def lagr_auto_cor(data: np.ndarray, oflow: np.ndarray, config: dict): - """ - Generate the Lagrangian auto correlations for STEPS cascades. - - Args: - data (np.ndarray): [T, L, M, N] where: - - T = ar_order + 1 (number of time steps) - - L = number of cascade levels - - M, N = spatial dimensions. - oflow (np.ndarray): [2, M, N] Optical flow vectors. - config (dict): Configuration dictionary containing: - - "n_cascade_levels": Number of cascade levels (L). - - "ar_order": Autoregressive order (1 or 2). - - "extrapolation_method": Method for extrapolating fields. - - Returns: - np.ndarray: Autocorrelation coefficients of shape (L, ar_order). - """ - - n_cascade_levels = config["pysteps"]["n_cascade_levels"] - ar_order = config["pysteps"]["ar_order"] - e_method = config["pysteps"]["extrapolation_method"] - - if data.shape[0] < (ar_order + 1): - raise ValueError( - f"Insufficient time steps. Expected at least {ar_order + 1}, got {data.shape[0]}.") - - extrapolation_method = extrapolation.get_method(e_method) - autocorrelation_coefficients = np.full( - (n_cascade_levels, ar_order), np.nan) - - for level in range(n_cascade_levels): - lag_1 = extrapolation_method(data[-2, level], oflow, 1)[0] - lag_1 = np.where(np.isfinite(lag_1), lag_1, 0) - - data_t = np.where(np.isfinite(data[-1, level]), data[-1, level], 0) - if np.std(lag_1) > 1e-1 and np.std(data_t) > 1e-1: - autocorrelation_coefficients[level, 0] = np.corrcoef( - lag_1.flatten(), data_t.flatten())[0, 1] - - if ar_order == 2: - lag_2 = extrapolation_method(data[-3, level], oflow, 1)[0] - lag_2 = np.where(np.isfinite(lag_2), lag_2, 0) - - lag_1 = extrapolation_method(lag_2, oflow, 1)[0] - lag_1 = np.where(np.isfinite(lag_1), lag_1, 0) - - if np.std(lag_1) > 1e-1 and np.std(data_t) > 1e-1: - autocorrelation_coefficients[level, 1] = np.corrcoef( - lag_1.flatten(), data_t.flatten())[0, 1] - - return autocorrelation_coefficients - - -def process_files(file_names, db, config: dict): - """ - Loop over a list of files and calculate the STEPS parameters. 
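    Each returned dictionary bundles the "metadata", "rain_stats", "dbr_stats"
    and "cascade" entries that are later written to the parameter collection;
    the "cascade" entry is left as None when the field's nonzero fraction is
    below 0.05, i.e. when there is too little rain to estimate the Lagrangian
    autocorrelations reliably.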
- - Args: - file_names (list[str]): List of files to process - data_base (pymongo.MongoClient): MongoDB database - config (dict): Dictionary with pysteps configuration - - Returns: - list[dict]: List of steps parameter dictionaries - """ - ar_order = config["pysteps"]["ar_order"] - timestep = config["pysteps"]["timestep"] - time_step_mins = config["pysteps"]["timestep"] // 60 - db_zerovalue = config["pysteps"]["zerovalue"] - db_threshold = config["pysteps"]["threshold"] - scale_break = config['pysteps']["scale_break"] - kmperpixel = config['pysteps']["kmperpixel"] - name = config['name'] - - delta_time_step = datetime.timedelta(seconds=timestep) - - rain_col_name = f"{name}.rain" - rain_fs = gridfs.GridFS(db, collection=rain_col_name) - - params = [] - lag_2, lag_1, lag_0 = None, None, None - oflow_0 = None - - for file_name in file_names: - field = rain_fs.find_one({"filename": file_name}) - if field is None: - continue - - # Set up the field metadata - valid_time = field.metadata["valid_time"] - valid_time = valid_time.replace(tzinfo=datetime.timezone.utc) - base_time = field.metadata["base_time"] - if base_time is not None: - base_time = base_time.replace(tzinfo=datetime.timezone.utc) - - ensemble = field.metadata["ensemble"] - product = field.metadata["product"] - metadata = { - "field_id": field._id, - "product": product, - "base_time": base_time, - "ensemble": ensemble, - "valid_time": valid_time, - "kmperpixel": kmperpixel # Need this when generating the stochastic fields - } - - # Read in the rain field - in_buffer = field.read() - rain_geodata, _, rain_data = read_nc(in_buffer) - # Needs to be consistent with db_threshold = -10 - rain_geodata["threshold"] = 0.1 - rain_geodata["zerovalue"] = 0 - rain_stats = compute_field_stats(rain_data, rain_geodata) - - if rain_geodata["transform"] is None: - db_data, db_geodata = transformation.dB_transform( - rain_data, rain_geodata, threshold=0.1, zerovalue=db_zerovalue - ) - db_data[~np.isfinite(db_data)] = db_geodata["zerovalue"] - else: - db_data = rain_data.copy() - - db_geodata["threshold"] = db_threshold - db_geodata["zerovalue"] = db_zerovalue - db_stats = compute_field_stats(db_data, db_geodata) - - # Compute the power spectrum and prob dist - steps_params = compute_field_parameters( - db_data, db_geodata, scale_break) - - # Read in the cascades and calculate Lagr auto correlation - cascade = {} - - # Only calculate cascade parameters if enough rain - if rain_stats["nonzero_fraction"] < 0.05: - steps_params.update({ - "metadata": metadata, - "rain_stats": rain_stats, - "dbr_stats": db_stats, - "cascade": None - }) - params.append(steps_params) - continue - - # Fetch states for (t, t-1, t-2) - query = { - "metadata.product": product, - "metadata.valid_time": {"$in": [valid_time, valid_time - delta_time_step, valid_time - 2 * delta_time_step]}, - "metadata.base_time": base_time, - "metadata.ensemble": ensemble - } - - states = get_states(db, name, query) - - # Set up the keys for the states - bkey = "NA" if base_time is None else base_time - ekey = "NA" if ensemble is None else ensemble - lag0_inx = (valid_time, bkey, ekey) - lag1_inx = (valid_time - delta_time_step, bkey, ekey) - lag2_inx = (valid_time - 2*delta_time_step, bkey, ekey) - - state = states.get(lag0_inx) - lag_0 = state["cascade"] if state is not None else None - oflow_0 = state["optical_flow"] if state is not None else None - state = states.get(lag1_inx) - lag_1 = state["cascade"] if state is not None else None - state = states.get(lag2_inx) - lag_2 = state["cascade"] if 
state is not None else None - - # set up the cascade level means and stds for valid_time - stds = lag_0.get("stds") if lag_0 else None - cascade["stds"] = stds - means = lag_0.get("means") if lag_0 else None - cascade["means"] = means - - # Calculate the Lagr auto correl if enough data - num_valid = sum(x is not None for x in [lag_2, lag_1, lag_0]) - if num_valid == ar_order + 1 and oflow_0 is not None: - data = np.array([lag_1["cascade_levels"], lag_0["cascade_levels"]]) if ar_order == 1 else np.array( - [lag_2["cascade_levels"], lag_1["cascade_levels"], lag_0["cascade_levels"]]) - auto_cor = lagr_auto_cor(data, oflow_0, config) - - # calculate the correlation lengths (minutes) - lag1_list = auto_cor[:, 0].tolist() - lag2_list = auto_cor[:, 1].tolist() if ar_order == 2 else [ - None] * len(lag1_list) - corl_list = [ - correlation_length(l1, l2, time_step_mins) - for l1, l2 in zip(lag1_list, lag2_list) - ] - - cascade.update({ - "lag1": lag1_list, - "lag2": lag2_list, - "corl": corl_list, - "corl_zero":corl_list[0] - }) - else: - cascade.update({ - "lag1": None, - "lag2": None, - "corl": None, - "corl_zero":None - }) - - steps_params.update({ - "metadata": metadata, - "rain_stats": rain_stats, - "dbr_stats": db_stats, - "cascade": cascade - }) - params.append(steps_params) - - return params - - -def main(): - - parser = argparse.ArgumentParser( - description="Calculate STEPS parameters") - - parser.add_argument('-s', '--start', type=str, required=True, - help='Start time yyyy-mm-ddTHH:MM:SS') - parser.add_argument('-e', '--end', type=str, required=True, - help='End time yyyy-mm-ddTHH:MM:SS') - parser.add_argument('-n', '--name', type=str, required=True, - help='Name of domain [AKL]') - parser.add_argument('-p', '--product', type=str, required=True, - help='Name of input product [QPE, auckprec, qpesim]') - - args = parser.parse_args() - - # Include app name (module name) in log output - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s [%(levelname)s] %(name)s: %(message)s', - stream=sys.stdout - ) - - logger = logging.getLogger(__name__) - logger.info("Calculating STEPS parameters") - - # Validate start and end time and read them in - if args.start and is_valid_iso8601(args.start): - start_time = datetime.datetime.fromisoformat(str(args.start)) - if start_time.tzinfo is None: - start_time = start_time.replace(tzinfo=datetime.timezone.utc) - else: - logging.error( - "Invalid start time format. Please provide a valid ISO 8601 time string.") - return - - if args.end and is_valid_iso8601(args.end): - end_time = datetime.datetime.fromisoformat(str(args.end)) - if end_time.tzinfo is None: - end_time = end_time.replace(tzinfo=datetime.timezone.utc) - else: - logging.error( - "Invalid start time format. Please provide a valid ISO 8601 time string.") - return - - name = str(args.name) - product = str(args.product) - if product not in ["QPE", "auckprec", "qpesim"]: - logging.error( - "Invalid product. 
Please provide either 'QPE', 'auckprec', or 'qpesim'.") - return - - db = get_db() - config_coll = db["config"] - record = config_coll.find_one({'config.name': name}, sort=[ - ('time', pymongo.DESCENDING)]) - if record is None: - logging.error(f"Could not find configuration for domain {name}") - return - - config = record['config'] - meta_coll = db[f"{name}.rain.files"] - params_coll = db[f"{name}.params"] - - # Single pass through the data for qpe product - if product == "QPE": - f_filter = { - "metadata.product": product, - "metadata.valid_time": {"$gte": start_time, "$lte": end_time}, - "metadata.wetted_area_ratio": {"$gte": WAR_THRESHOLD} - } - - fields = {"_id": 0, "filename": 1, "metadata.wetted_area_ratio": 1} - results = meta_coll.find(filter=f_filter, projection=fields).sort( - "filename", pymongo.ASCENDING) - if results is None: - logging.error( - f"Failed to find {product}data for {start_time} - {end_time}") - return - - file_names = [doc["filename"] for doc in results] - logging.info( - f"Found {len(file_names)} {product} fields to process between {start_time} and {end_time}") - steps_params = process_files(file_names, db, config) - - params_coll.delete_many({ - "metadata.product": product, - "metadata.valid_time": {"$gte": start_time, "$lte": end_time} - }) - if steps_params: - params_coll.insert_many(steps_params) - else: - # Get the list of unique nwp run times in this period in ascending order - start_base_time = start_time - datetime.timedelta(hours=12) - base_time_query = { - "metadata.product": product, - "metadata.base_time": {"$gte": start_base_time, "$lte": end_time} - } - base_times = meta_coll.distinct("metadata.base_time", base_time_query) - if base_times is None: - logging.error( - f"Failed to find {product} data for {start_time} - {end_time}") - return - - base_times.sort() - - for base_time in base_times: - - # Get the list of unique ensembles found at base_time - ensembles = meta_coll.distinct( - "metadata.ensemble", {"metadata.product": product, "metadata.base_time": base_time}) - - if not ensembles: - logging.warning( - f"No ensembles found for base_time {base_time}") - continue # Skip this base_time if no ensembles exist - - logging.info( - f"Found {len(ensembles)} ensembles for base_time {base_time}") - ensembles.sort() - - for ensemble in ensembles: - # Get all the forecasts for this base_time and ensemble and process - - f_filter = { - "metadata.product": product, - "metadata.valid_time": {"$gte": start_time, "$lte": end_time}, - "metadata.base_time": base_time, - "metadata.ensemble": ensemble, - "metadata.wetted_area_ratio": {"$gte": WAR_THRESHOLD} - } - - fields = {"_id": 0, "filename": 1, - "metadata.wetted_area_ratio": 1} - results = meta_coll.find(filter=f_filter, projection=fields).sort( - "filename", pymongo.ASCENDING) - file_names = [doc["filename"] for doc in results] - - if len(file_names) > 0: - steps_params = process_files( - file_names, db, config) - params_coll.delete_many({ - "metadata.product": product, - "metadata.valid_time": {"$gte": start_time, "$lte": end_time}, - "metadata.base_time": base_time, - "metadata.ensemble": ensemble, - }) - if steps_params: - params_coll.insert_many(steps_params) - - -if __name__ == "__main__": - main() diff --git a/pysteps/param/nc_utils.py b/pysteps/param/nc_utils.py new file mode 100644 index 000000000..17a304172 --- /dev/null +++ b/pysteps/param/nc_utils.py @@ -0,0 +1,107 @@ +import xarray as xr +import pandas as pd +import numpy as np +import datetime +import logging + + +def 
generate_geo_dict(domain): + ncols = domain.get("n_cols") + nrows = domain.get("n_rows") + psize = domain.get("p_size") + start_x = domain.get("start_x") + start_y = domain.get("start_y") + x = [start_x + i * psize for i in range(ncols)] + y = [start_y + i * psize for i in range(nrows)] + + out_geo = {} + out_geo["x"] = x + out_geo["y"] = y + out_geo["xpixelsize"] = psize + out_geo["ypixelsize"] = psize + out_geo["x1"] = start_x + out_geo["y1"] = start_y + out_geo["x2"] = start_x + (ncols - 1) * psize + out_geo["y2"] = start_y + (nrows - 1) * psize + out_geo["projection"] = domain["projection"]["epsg"] + out_geo["cartesian_unit"] = "m" + out_geo["yorigin"] = "lower" + out_geo["unit"] = "mm/h" + out_geo["threshold"] = 0 + out_geo["transform"] = None + + return out_geo + + +def generate_geo_dict_xy(x: np.ndarray, y: np.ndarray, epsg: str): + n_cols = x.size + n_rows = y.size + + out_geo = {} + out_geo["xpixelsize"] = (x[-1] - x[0]) / (n_cols - 1) + out_geo["ypixelsize"] = (y[-1] - y[0]) / (n_rows - 1) + out_geo["x1"] = x[0] + out_geo["x2"] = x[-1] + out_geo["y1"] = y[0] + out_geo["y2"] = y[-1] + out_geo["projection"] = epsg + out_geo["cartesian_unit"] = "m" + out_geo["yorigin"] = "lower" + out_geo["unit"] = "mm/h" + out_geo["threshold"] = 0 + out_geo["transform"] = None + + return out_geo + + +def read_qpe_netcdf(file_path): + """ + Read WRNZ QPE NetCDF file and return xarray Dataset of rain rate with: + - 'rain' variable in [time, yc, xc] order + - time as timezone-aware UTC datetimes + - EPSG:2193 (NZTM2000) projection info added using CF conventions + - Return None on error reading the file + + Assumes that the input file is rain rate in [t,y,x] order + """ + + try: + ds = xr.open_dataset(file_path, decode_times=True) + ds.load() + + # Make the times timezone-aware UTC + time_values = ds["time"].values.astype("datetime64[ns]") + time_utc = pd.DatetimeIndex(time_values, tz=datetime.UTC) + ds["time"] = ("time", time_utc) + + # Rename + ds = ds.rename({"rainfall": "rain"}) + + # Define CF-compliant grid mapping for EPSG:2193 + crs = xr.DataArray( + 0, + attrs={ + "grid_mapping_name": "transverse_mercator", + "scale_factor_at_central_meridian": 0.9996, + "longitude_of_central_meridian": 173.0, + "latitude_of_projection_origin": 0.0, + "false_easting": 1600000.0, + "false_northing": 10000000.0, + "semi_major_axis": 6378137.0, + "inverse_flattening": 298.257222101, + "spatial_ref": "EPSG:2193", + }, + name="NZTM2000", + ) + + ds["NZTM2000"] = crs + ds["rain"].attrs["grid_mapping"] = "NZTM2000" + + ds = ds[["rain", "NZTM2000"]] + ds = ds.assign_coords(time=ds["time"], yc=ds["y"], xc=ds["x"]) + + return ds + + except (ValueError, OverflowError, TypeError) as e: + logging.warning(f"Failed to read {file_path}: {e}") + return None diff --git a/pysteps/param/nwp_param_qc.py b/pysteps/param/nwp_param_qc.py deleted file mode 100644 index 6187d250c..000000000 --- a/pysteps/param/nwp_param_qc.py +++ /dev/null @@ -1,199 +0,0 @@ -import argparse -import datetime -import logging -import numpy as np -import pandas as pd -from pymongo import UpdateOne -from models.mongo_access import get_db, get_config -from models.steps_params import power_law_acor, StochasticRainParameters - -from statsmodels.tsa.api import SimpleExpSmoothing -import pymongo.collection -from typing import Dict -logging.basicConfig(level=logging.INFO) - - -def get_parameters_df(query: Dict, param_coll: pymongo.collection.Collection) -> pd.DataFrame: - """ - Retrieve STEPS parameters from the database and return a DataFrame - indexed by 
(valid_time, base_time, ensemble), using 'NA' as sentinel for missing values. - - Args: - query (dict): MongoDB query dictionary. - param_coll (pymongo.collection.Collection): MongoDB collection. - - Returns: - pd.DataFrame: Indexed by (valid_time, base_time, ensemble), with a 'param' column. - """ - records = [] - - for doc in param_coll.find(query).sort("metadata.valid_time", pymongo.ASCENDING): - try: - metadata = doc.get("metadata", {}) - if metadata is None: - continue - - if doc["cascade"]["lag1"] is None or doc["cascade"]["lag2"] is None: - continue - - valid_time = metadata.get("valid_time") - valid_time = pd.to_datetime(valid_time,utc=True) - - base_time = metadata.get("base_time") - if base_time is None: - base_time = pd.NaT - else: - base_time = pd.to_datetime(base_time, utc=True) - - ensemble = metadata.get("ensemble") - - param = StochasticRainParameters.from_dict(doc) - - param.calc_corl() - records.append({ - "valid_time": valid_time, - "base_time": base_time, - "ensemble": ensemble, - "param": param - }) - except Exception as e: - print( - f"Warning: could not parse parameter for {metadata.get('valid_time')}: {e}") - - if not records: - return pd.DataFrame(columns=["valid_time", "base_time", "ensemble", "param"]) - - df = pd.DataFrame(records) - return df - -def parse_args(): - parser = argparse.ArgumentParser( - description="QC and update NWP lag autocorrelations") - parser.add_argument("-n", "--name", required=True, help="Domain name, e.g., AKL") - parser.add_argument("-p","--product", required=True, - help="Product name, e.g., auckprec") - parser.add_argument("-b","--base_time", required=True, - help="Base time, ISO format UTC (e.g., 2023-01-26T03:00:00)") - parser.add_argument("--dry_run", action="store_true", - help="Run without writing to database") - return parser.parse_args() - - -def qc_update_autocorrelations(dry_run: bool, name: str, product: str, base_time: datetime.datetime): - db = get_db() - config = get_config(db, name) - dt = datetime.timedelta(seconds=config["pysteps"]["timestep"]) - dt_seconds = dt.total_seconds() - - corl_pvals = config["dynamic_scaling"]["cor_len_pvals"] - corl_max = max(corl_pvals) - corl_min = min(corl_pvals) - - query = { - "metadata.product": product, - "metadata.base_time": base_time, - } - - param_coll = db[f"{name}.params"] - df = get_parameters_df(query, param_coll) - if df.empty: - logging.warning("No parameters found for the given base_time.") - return - - # Build corl_0 time series - records = [] - - ensembles = df["ensemble"].unique() - ensembles = np.sort(ensembles) - valid_times = df["valid_time"].unique() - t_min = min(valid_times) - t_max = max(valid_times) - all_times = pd.date_range(start=t_min, end=t_max, freq=dt, tz="UTC") - - # Convert the base_time to datetime64 for working with dataframe - vbase_time = pd.NaT - if base_time is not None: - vbase_time = pd.to_datetime(base_time,utc=True) - - for ens in ensembles: - ens_df = df.loc[ (df['base_time'] == vbase_time) & - (df['ensemble'] == ens),["valid_time", "param"] ].set_index("valid_time") - if ens_df.empty: - continue - - for vt in all_times: - try: - param = ens_df.loc[ vt,"param"] - corl_0 = param.corl_zero - - # Threshold at the 5 and 95 percentile values - corl_0 = corl_min if corl_0 < corl_min else corl_0 - corl_0 = corl_max if corl_0 > corl_max else corl_0 - except KeyError: - corl_0 = np.nan - - records.append({ - "valid_time": vt, - "ensemble": ens, - "corl_0": corl_0 - }) - - corl_df = pd.DataFrame.from_records(records) - corl_df = 
corl_df.sort_values(["ensemble", "valid_time"]) - updates = [] - - for ens in ensembles: - ens_df = corl_df[corl_df["ensemble"] == ens].set_index("valid_time") - if ens_df["corl_0"].isnull().all(): - logging.info(f"No valid corl_0 values for ensemble {ens}, skipping.") - continue - - mean_corl = ens_df["corl_0"].mean() - ens_df["corl_0"] = ens_df["corl_0"].fillna(mean_corl) - ens_df.index.freq = pd.Timedelta(seconds=dt_seconds) - - # Apply smoothing - model = SimpleExpSmoothing(ens_df["corl_0"], initialization_method="estimated").fit( - smoothing_level=0.2, optimized=False) - ens_df["corl_0_smoothed"] = model.fittedvalues - - for vt in ens_df.index: - T_ref = ens_df.loc[vt, "corl_0_smoothed"] - lags, corl = power_law_acor(config, T_ref) - valid_time = vt.to_pydatetime() - - updates.append(UpdateOne( - { - "metadata.product": product, - "metadata.valid_time": valid_time, - "metadata.base_time": base_time, - "metadata.ensemble": int(ens) - }, - { - "$set": { - "cascade.lag1": [float(x) for x in lags[:, 0]], - "cascade.lag2": [float(x) for x in lags[:, 1]], - "cascade.corl": [float(x) for x in corl], - "cascade.corl_zero":float(corl[0]) - } - }, - upsert=False - )) - if updates: - if dry_run: - logging.info( - f"{len(updates)} updates prepared (dry run, not written)") - else: - result = param_coll.bulk_write(updates) - logging.info(f"Updated {result.modified_count} documents.") - else: - logging.info("No documents to update.") - - -if __name__ == "__main__": - args = parse_args() - dry_run = args.dry_run - base_time = datetime.datetime.fromisoformat( - args.base_time).replace(tzinfo=datetime.timezone.utc) - qc_update_autocorrelations( - dry_run, args.name, args.product, base_time) diff --git a/pysteps/param/pysteps_param.py b/pysteps/param/pysteps_param.py deleted file mode 100644 index 8e09486ea..000000000 --- a/pysteps/param/pysteps_param.py +++ /dev/null @@ -1,369 +0,0 @@ -from typing import List, Callable -import argparse -import logging -import datetime -import numpy as np -import copy -import os -import sys -import pandas as pd -from cascade.bandpass_filters import filter_gaussian -from utils import transformation -from cascade.decomposition import decomposition_fft -from mongo.nc_utils import generate_geo_data, make_nc_name_dt, write_netcdf -from mongo.gridfs_io import get_states, load_rain_field -from mongo.mongo_access import get_base_time, get_parameters_df -from steps_params import StochasticRainParameters, blend_parameters -from shared_utils import initialize_config -from shared_utils import zero_state, update_field - - -def get_weight(lag): - width = 3 * 3600 - weight = np.exp(-(lag/width)**2) - return weight - - -def main(): - - parser = argparse.ArgumentParser(description="Run nwpblend forecasts") - parser.add_argument('-b', '--base_time', required=True, - help='Base time in ISO 8601 format') - parser.add_argument('-n', '--name', required=True, - help='Domain name (e.g., AKL)') - args = parser.parse_args() - - - # Include app name (module name) in log output - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s [%(levelname)s] %(name)s: %(message)s', - stream=sys.stdout - ) - - logger = logging.getLogger(__name__) - logger.info("Gemerating nwpblend ensembles") - - name = args.name - db, config, out_base_time = initialize_config(args.base_time, name) - - param_coll = db[f"{name}.params"] - meta_coll = db[f"{name}.rain.files"] - - time_step_seconds = config['pysteps']['timestep'] - time_step = datetime.timedelta(seconds=time_step_seconds) - ar_order = 
config['pysteps']['ar_order'] - n_levels = config['pysteps']['n_cascade_levels'] - db_threshold = config['pysteps']['threshold'] - scale_break = config['pysteps']['scale_break'] - - # Set up the georeferencing data for the output forecasts - domain = config['domain'] - start_x = domain['start_x'] - start_y = domain['start_y'] - p_size = domain['p_size'] - n_rows = domain['n_rows'] - n_cols = domain['n_cols'] - x = [start_x + i * p_size for i in range(n_cols)] - y = [start_y + i * p_size for i in range(n_rows)] - geo_data = generate_geo_data(x, y) - geo_data["projection"] = config['projection']["epsg"] - - # Set up the bandpass filter - p_size_km = p_size / 1000.0 - bp_filter = filter_gaussian((n_rows, n_cols), n_levels, d=p_size_km) - - # Configure the output product - out_product = "nwpblend" - out_config = config['output'][out_product] - n_ens = out_config.get('n_ens_members', 10) - n_forecasts = out_config.get('n_forecasts', 12) - rad_product = out_config.get('rad_product', None) - nwp_product = out_config.get('nwp_product', None) - gridfs_out = out_config.get('gridfs_out', False) - nc_out = out_config.get('nc_out', False) - out_dir_name = out_config.get('out_dir_name', None) - out_file_name = out_config.get( - 'out_file_name', "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc") - - # Validate the output configuration details - if rad_product is None: - logging.error(f"Radar product not specified") - return - if n_ens < 1: - logging.error(f"Invalid number of ensemble members: {n_ens}") - return - if n_forecasts < 1: - logging.error(f"Invalid number of lead times: {n_forecasts}") - return - if not gridfs_out and not nc_out: - logging.error( - "No output format specified. Please set either gridfs_out or nc_out to True.") - return - - if nc_out: - if out_dir_name is None: - logging.error(f"No output directory name found") - return - - logging.info(f"Generating nwpblend for {out_base_time}") - - # Make the list of output forecast times - forecast_times = [out_base_time + ia * - time_step for ia in range(0, n_forecasts+1)] - - - # Get the initial state(s) for the input radar field at this base time - # Set any missing states to None - base_time_key = "NA" - ensemble_key = "NA" - init_times = [out_base_time] - if ar_order == 2: - init_times = [out_base_time - time_step, out_base_time] - - query = { - "metadata.product": rad_product, - "metadata.valid_time": {"$in": init_times} - } - init_state = get_states(db, name, query, - get_cascade=True, get_optical_flow=True) - rad_params_df = get_parameters_df(query, param_coll) - - for vtime in init_times: - key = (vtime, base_time_key, ensemble_key) - if key not in init_state: - init_state[key] = zero_state(config) - logging.debug(f"Found missing QPE oflow for {vtime}") - - # Check if row exists for this combination - mask = ( - (rad_params_df["valid_time"] == vtime) & - (rad_params_df["base_time"] == "NA") & - (rad_params_df["ensemble"] == "NA") - ) - - if rad_params_df[mask].empty: - logging.debug(f"Found missing QPE parametersfor {vtime}") - - def_param = StochasticRainParameters() - def_param.calc_acor(config) - def_param.kmperpixel = p_size_km - def_param.scale_break = scale_break - def_param.threshold = db_threshold - - new_row = { - "valid_time": vtime, - "base_time": "NA", - "ensemble": "NA", - "param": def_param - } - rad_params_df = pd.concat([rad_params_df, pd.DataFrame([new_row])], ignore_index=True) - - # Get the base_time for the nwp run nearest to the output base_time - nwp_base_time = get_base_time(out_base_time, nwp_product, 
name, db) - - # Get the list of ensemble members for this nwp_base_time - query = { - "metadata.product": nwp_product, - "metadata.base_time": nwp_base_time} - nwp_ensembles = meta_coll.distinct("metadata.ensemble", query) - if nwp_ensembles is None: - logging.warning( - f"Failed to find ensembles for {nwp_product} data for {out_base_time}") - nwp_ensembles.sort() - n_nwp_ens = len(nwp_ensembles) - - # Get the NWP parameters and optical flows for the NWP ensemble - query = { - "metadata.product": nwp_product, - "metadata.valid_time": {"$in": forecast_times}, - 'metadata.base_time': nwp_base_time - } - nwp_params_df = get_parameters_df(query, param_coll) - nwp_oflows = get_states( - db, name, query, get_cascade=False, get_optical_flow=True) - - # Start the loop over the ensemble members - for iens in range(n_ens): - - # Calculate the set of blended parameters for this output ensemble - # Get the radar parameter - qpe_rows = rad_params_df[ - (rad_params_df["valid_time"] == out_base_time) & - (rad_params_df["base_time"] == "NA") & - (rad_params_df["ensemble"] == "NA") - ] - rad_param = qpe_rows.iloc[0]["param"] - - # Randomly select an ensemble member from the NWP - nwp_ens = np.random.randint(low=0, high=n_nwp_ens) - nwp_ensemble_df = nwp_params_df[ - (nwp_params_df["base_time"] == nwp_base_time) & - (nwp_params_df["ensemble"] == nwp_ens) - ][["valid_time", "param"]].copy() - nwp_ensemble_df["valid_time"] = pd.to_datetime(nwp_ensemble_df["valid_time"]) - nwp_ensemble_df.set_index("valid_time", inplace=True) - nwp_ensemble_df = nwp_ensemble_df.sort_index() - - # Fill in any missing forecast times with default parameters - for vtime in forecast_times: - if vtime not in nwp_ensemble_df.index: - def_param = StochasticRainParameters() - def_param.calc_acor(config) - def_param.kmperpixel = p_size_km - def_param.scale_break = scale_break - def_param.threshold = db_threshold - nwp_ensemble_df.loc[vtime,"param"] = def_param - - # Blend the parameters - blend_params_df = blend_parameters(config, out_base_time, nwp_ensemble_df, rad_param) - - # Set up the initial conditions for the forecast loop - # The order is [t-1, t0] in init_times for AR(2) - if ar_order == 1: - key = (init_times[0], "NA", "NA") - state = init_state.get(key) - - if state is not None: - cascade = state.get("cascade") - optical = state.get("optical_flow") - fx_cascades = [copy.deepcopy(cascade)] if cascade is not None else [None] - fx_oflow = copy.deepcopy(optical) if optical is not None else None - else: - fx_cascades = [None] - fx_oflow = None - - else: # AR(2) - key_0 = (init_times[0], "NA", "NA") - key_1 = (init_times[1], "NA", "NA") - - state_0 = init_state.get(key_0) - state_1 = init_state.get(key_1) - - if state_0 is not None and state_1 is not None: - casc_0 = state_0.get("cascade") - casc_1 = state_1.get("cascade") - optical = state_1.get("optical_flow") - - fx_cascades = [ - copy.deepcopy(casc_0) if casc_0 is not None else None, - copy.deepcopy(casc_1) if casc_1 is not None else None - ] - fx_oflow = copy.deepcopy(optical) if optical is not None else None - else: - fx_cascades = [None, None] - fx_oflow = None - - # Start the forecast loop - for ifx in range(1, n_forecasts+1): - valid_time = forecast_times[ifx] - fx_param = blend_params_df.loc[valid_time, "param"] - - fx_dbrain = update_field( - fx_cascades, fx_oflow, fx_param, bp_filter, config) - has_nan = np.isnan(fx_dbrain).any() if fx_dbrain is not None else True - - if has_nan : - fx_rain = np.zeros((n_rows, n_cols)) - else: - fx_rain, _ = transformation.dB_transform( 
- fx_dbrain, inverse=True, threshold=db_threshold, zerovalue=0) - - # Make the output file name - fx_file_name = make_nc_name_dt( - out_file_name, name, out_product, valid_time, out_base_time, iens) - - # Write the NetCDF data to a memoryview buffer - # This is an ugly hack on time zones - vtime = valid_time - if vtime.tzinfo is None: - vtime = vtime.replace(tzinfo=datetime.timezone.utc) - btime = out_base_time - if btime.tzinfo is None: - btime = btime.replace(tzinfo=datetime.timezone.utc) - vtime_stamp = vtime.timestamp() - - nc_buf = write_netcdf(fx_rain, geo_data, vtime_stamp) - - if gridfs_out: - # Create metadata - rain_mask = fx_rain.copy() - rain_mask[rain_mask < 1] = 0 - rain_mask[rain_mask > 0] = 1 - war = rain_mask.sum() / (n_cols * n_rows) - mean = np.nanmean(fx_rain) - std_dev = np.nanstd(fx_rain) - max = np.nanmax(fx_rain) - metadata = { - "product": out_product, - "domain": name, - "ensemble": int(iens), - "base_time": btime, - "valid_time": vtime, - "mean": float(mean), - "wetted_area_ratio": float(war), - "std_dev": float(std_dev), - "max": float(max), - "forecast_lead_time": int(ifx*time_step_seconds) - } - load_rain_field(db, name, fx_file_name, nc_buf, metadata) - - if nc_out: - fx_dir_name = make_nc_name_dt( - out_dir_name, name, out_product, valid_time, out_base_time, iens) - if not os.path.exists(fx_dir_name): - os.makedirs(fx_dir_name) - fx_file_path = os.path.join(fx_dir_name, fx_file_name) - with open(fx_file_path, 'wb') as f: - f.write(nc_buf.tobytes()) - - # Update the cascade state list for the next forecast step - if ar_order == 2: - # Push the cascade history down (t0 → t-1) - fx_cascades[0] = copy.deepcopy(fx_cascades[1]) - - # Update the latest cascade (t0) from current forecast brain - if fx_dbrain is not None: - if has_nan: - fx_cascades[1] = zero_state(config)["cascade"] - logging.warning(f"NaNs found for {valid_time}, {iens} ") - else: - fx_cascades[1] = decomposition_fft( - fx_dbrain, bp_filter, compute_stats=True, normalize=True - ) - else: - fx_cascades[1] = zero_state(config)["cascade"] - - elif ar_order == 1: - # Only update the current cascade - if fx_dbrain is not None: - if has_nan: - fx_cascades[0] = zero_state(config)["cascade"] - logging.warning(f"NaNs found for {valid_time}, {iens} ") - else: - fx_cascades[0] = decomposition_fft( - fx_dbrain, bp_filter, compute_stats=True, normalize=True - ) - else: - fx_cascades[0] = zero_state(config)["cascade"] - - # Update the optical flow field using radar–NWP blending - if ifx < n_forecasts: - rad_key = (out_base_time, "NA", "NA") - nwp_key = (out_base_time, nwp_base_time, nwp_ens) - - lag = (valid_time - out_base_time).total_seconds() - weight = get_weight(lag) - - # Check availability of both radar and NWP optical flows - rad_oflow = init_state.get(rad_key, {}).get("optical_flow") - nwp_oflow_entry = nwp_oflows.get(nwp_key) - nwp_oflow = nwp_oflow_entry.get("optical_flow") if nwp_oflow_entry else None - - if rad_oflow is not None and nwp_oflow is not None: - fx_oflow = weight * rad_oflow + (1 - weight) * nwp_oflow - else: - fx_oflow = None - -if __name__ == "__main__": - main() diff --git a/pysteps/param/rainfield_stats.py b/pysteps/param/rainfield_stats.py new file mode 100644 index 000000000..a8486b703 --- /dev/null +++ b/pysteps/param/rainfield_stats.py @@ -0,0 +1,509 @@ +# Contains: RainfieldStats (dataclass, from_dict, to_dict), compute_field_parameters, compute_field_stats +""" +Functions to calculate rainfield statistics +""" +from typing import Optional, Tuple, Dict, List, Any +import 
datetime +from dataclasses import dataclass +import xarray as xr +from scipy.optimize import curve_fit +import numpy as np + +MAX_RAIN_RATE = 250 +N_BINS = 200 + + +@dataclass +class RainfieldStats: + domain: Optional[str] = None + product: Optional[str] = None + valid_time: Optional[datetime.datetime] = None + base_time: Optional[datetime.datetime] = None + ensemble: Optional[int] = None + filename: Optional[str] = None + + transform: Optional[str] = None + zerovalue: Optional[float] = None + threshold: Optional[float] = None + kmperpixel: Optional[float] = None + + mean_db: Optional[float] = None + stdev_db: Optional[float] = None + nonzero_mean_db: Optional[float] = None + nonzero_stdev_db: Optional[float] = None + + mean_rain: Optional[float] = None + stdev_rain: Optional[float] = None + rain_fraction: Optional[float] = None + nonzero_mean_rain: Optional[float] = None + nonzero_stdev_rain: Optional[float] = None + + psd: Optional[List[float]] = None + psd_bins: Optional[List[float]] = None + c1: Optional[float] = None + c2: Optional[float] = None + scale_break: Optional[float] = None + beta_1: Optional[float] = None + beta_2: Optional[float] = None + corl_zero: Optional[float] = None + + cdf: Optional[List[float]] = None + cdf_bins: Optional[List[float]] = None + + cascade_stds: Optional[List[float]] = None + cascade_means: Optional[List[float]] = None + cascade_lag1: Optional[List[float]] = None + cascade_lag2: Optional[List[float]] = None + cascade_corl: Optional[List[float]] = None + + def get(self, key: str, default: Any = None) -> Any: + """Mimic dict.get() for selected attributes.""" + return getattr(self, key, default) + + def calc_corl(self): + """Populate the correlation lengths using lag1 and lag2 values.""" + + # Make sure that we have defined auto-correlations + if self.cascade_lag1 is None or self.cascade_lag2 is None: + self.cascade_corl = None + return + + # We have defined auto-correlations so set up the correlation length array + n_levels = len(self.cascade_lag1) + self.cascade_corl = [np.nan] * n_levels + for ilev in range(n_levels): + lag1 = self.cascade_lag1[ilev] + lag2 = self.cascade_lag2[ilev] + self.cascade_corl[ilev] = correlation_length(lag1, lag2) + self.corl_zero = self.cascade_corl[0] + + def calc_acor(self, config) -> None: + T_ref = self.corl_zero + if T_ref is None or np.isnan(T_ref): + T_ref = config["dynamic_scaling"]["cor_len_pvals"][1] + + acor, corl = power_law_acor(config, T_ref) + self.cascade_corl = [float(x) for x in corl] + self.cascade_lag1 = [float(x) for x in acor[:, 0]] + self.cascade_lag2 = [float(x) for x in acor[:, 1]] + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "RainfieldStats": + dbr = data.get("dbr_stats", {}) + rain = data.get("rain_stats", {}) + pspec = data.get("power_spectrum", {}) + model = pspec.get("model", {}) if pspec else {} + cdf_data = data.get("cdf", {}) + cascade = data.get("cascade", {}) + meta = data.get("metadata", {}) + + kwargs = { + "product": meta.get("product"), + "valid_time": meta.get("valid_time"), + "base_time": meta.get("base_time"), + "ensemble": meta.get("ensemble"), + "filename": meta.get("filename"), + "kmperpixel": meta.get("kmperpixel"), + "transform": dbr.get("transform"), + "zerovalue": dbr.get("zerovalue"), + "threshold": dbr.get("threshold"), + "nonzero_mean_db": dbr.get("nonzero_mean"), + "nonzero_stdev_db": dbr.get("nonzero_stdev"), + "rain_fraction": dbr.get("nonzero_fraction"), + "mean_db": dbr.get("mean"), + "stdev_db": dbr.get("stdev"), + "nonzero_mean_rain": 
rain.get("nonzero_mean"), + "nonzero_stdev_rain": rain.get("nonzero_stdev"), + "mean_rain": rain.get("mean"), + "stdev_rain": rain.get("stdev"), + "psd": pspec.get("psd"), + "psd_bins": pspec.get("psd_bins"), + "c1": model.get("c1"), + "c2": model.get("c2"), + "scale_break": model.get("scale_break"), + "beta_1": model.get("beta_1"), + "beta_2": model.get("beta_2"), + "cdf": cdf_data.get("cdf"), + "cdf_bins": cdf_data.get("cdf_bins"), + "corl_zero": cascade.get("corl_zero"), + "cascade_stds": cascade.get("stds"), + "cascade_means": cascade.get("means"), + "cascade_lag1": cascade.get("lag1"), + "cascade_lag2": cascade.get("lag2"), + } + + # Make sure the times are UTC + ttime = kwargs["valid_time"] + if ttime is not None and ttime.tzinfo is None: + ttime = ttime.replace(tzinfo=datetime.timezone.utc) + kwargs["valid_time"] = ttime + + ttime = kwargs["base_time"] + if ttime is not None and ttime.tzinfo is None: + ttime = ttime.replace(tzinfo=datetime.timezone.utc) + kwargs["base_time"] = ttime + + # Add cascade_corl explicitly, since it's constructed dynamically + lag1_list = cascade.get("lag1", []) + kwargs["cascade_corl"] = [np.nan] * len(lag1_list) if lag1_list else None + + return cls(**kwargs) + + def to_dict(self) -> Dict[str, Any]: + return { + "dbr_stats": { + "transform": self.transform, + "zerovalue": self.zerovalue, + "threshold": self.threshold, + "nonzero_mean": self.nonzero_mean_db, + "nonzero_stdev": self.nonzero_stdev_db, + "nonzero_fraction": self.rain_fraction, + "mean": self.mean_db, + "stdev": self.stdev_db, + }, + "rain_stats": { + "nonzero_mean": self.nonzero_mean_rain, + "nonzero_stdev": self.nonzero_stdev_rain, + "nonzero_fraction": self.rain_fraction, # assume same as dbr_stats + "mean": self.mean_rain, + "stdev": self.stdev_rain, + "transform": None, + "zerovalue": 0, + "threshold": 0.1, + }, + "power_spectrum": { + "psd": self.psd, + "psd_bins": self.psd_bins, + "model": ( + { + "beta_1": self.beta_1, + "beta_2": self.beta_2, + "c1": self.c1, + "c2": self.c2, + "scale_break": self.scale_break, + } + if any( + x is not None + for x in [ + self.beta_1, + self.beta_2, + self.c1, + self.c2, + self.scale_break, + ] + ) + else None + ), + }, + "cdf": { + "cdf": self.cdf, + "cdf_bins": self.cdf_bins, + }, + "cascade": ( + { + "corl_zero": self.corl_zero, + "stds": self.cascade_stds, + "means": self.cascade_means, + "lag1": self.cascade_lag1, + "lag2": self.cascade_lag2, + "corl": self.cascade_corl, + } + if self.cascade_stds is not None + else None + ), + "metadata": { + "domain": self.domain, + "product": self.product, + "valid_time": self.valid_time, + "base_time": self.base_time, + "ensemble": self.ensemble, + "filename": self.filename, + "kmperpixel": self.kmperpixel, + }, + } + + +def compute_field_parameters( + db_data: np.ndarray, db_metadata: dict, scalebreak: Optional[float] = None +): + """ + Compute STEPS parameters for the dB transformed rainfall field + + Args: + db_data (np.ndarray): 2D field of dB-transformed rain. + db_metadata (dict): pysteps metadata dictionary. + + Returns: + dict: Dictionary containing STEPS parameters. 
+ """ + + ps_dataset, ps_model = power_spectrum_1D(db_data, scalebreak) + if ps_dataset is not None: + power_spectrum = { + "psd": ps_dataset.psd.values.tolist(), + "psd_bins": ps_dataset.psd_bins.values.tolist(), + "model": ps_model, + } + else: + power_spectrum = {} + + # Compute cumulative probability distribution + cdf_dataset = prob_dist(db_data, db_metadata) + cdf = { + "cdf": cdf_dataset.cdf.values.tolist(), + "cdf_bins": cdf_dataset.cdf_bins.values.tolist(), + } + + # Store parameters in a dictionary + field_params = {"power_spectrum": power_spectrum, "cdf": cdf} + return field_params + + +def power_spectrum_1D( + field: np.ndarray, scale_break: Optional[float] = None +) -> Tuple[Optional[xr.Dataset], Optional[Dict[str, float]]]: + """ + Calculate the 1D isotropic power spectrum and fit a power law model. + + Args: + field (np.ndarray): 2D input field in [rows, columns] order. + scale_break (float, optional): Scale break in pixel units. If None, fit single line. + + Returns: + ps_dataset (xarray.Dataset): 1D isotropic power spectrum in dB. + model_params (dict): Dictionary with model parameters: beta_1, beta_2, c1, c2, scale_break + """ + min_stdev = 0.1 + mean = np.nanmean(field) + stdev = np.nanstd(field) + if stdev < min_stdev: + return None, None + + norm_field = (field - mean) / stdev + np.nan_to_num(norm_field, copy=False) + + field_fft = np.fft.rfft2(norm_field) + power_spectrum = np.abs(field_fft) ** 2 + + freq_x = np.fft.fftfreq(field.shape[1]) + freq_y = np.fft.fftfreq(field.shape[0]) + freq_r = np.sqrt(freq_x[:, None] ** 2 + freq_y[None, :] ** 2) + freq_r = freq_r[: field.shape[0] // 2, : field.shape[1] // 2] + power_spectrum = power_spectrum[: field.shape[0] // 2, : field.shape[1] // 2] + + n_bins = power_spectrum.shape[0] + bins = np.logspace( + np.log10(freq_r.min() + 1 / n_bins), np.log10(freq_r.max()), num=n_bins + ) + bin_centers = (bins[:-1] + bins[1:]) / 2 + power_1d = np.zeros(len(bin_centers)) + + for i in range(len(bins) - 1): + mask = (freq_r >= bins[i]) & (freq_r < bins[i + 1]) + power_1d[i] = np.nanmean(power_spectrum[mask]) if np.any(mask) else np.nan + + valid = (bin_centers > 0) & (~np.isnan(power_1d)) + bin_centers = bin_centers[valid] + power_1d = power_1d[valid] + + if len(bin_centers) == 0: + return None, None + + log_x = 10 * np.log10(bin_centers) + log_y = 10 * np.log10(power_1d) + + start_idx = 2 + end_idx = np.searchsorted(log_x, -4.0) + + model_params = {} + + if scale_break is None: + + def str_line(X, m, c): + return m * X + c + + popt, _ = curve_fit( + str_line, log_x[start_idx:end_idx], log_y[start_idx:end_idx] + ) + beta_1, c1 = popt + beta_2 = None + c2 = None + sb_log = None + else: + sb_freq = 1.0 / scale_break + sb_log = 10 * np.log10(sb_freq) + + def piecewise_linear(x, m1, m2, c1): + c2 = (m1 - m2) * sb_log + c1 + return np.where(x <= sb_log, m1 * x + c1, m2 * x + c2) + + popt, _ = curve_fit( + piecewise_linear, log_x[start_idx:end_idx], log_y[start_idx:end_idx] + ) + beta_1, beta_2, c1 = popt + c2 = (beta_1 - beta_2) * sb_log + c1 + + ps_dataset = xr.Dataset( + {"psd": (["bin"], log_y)}, + coords={"psd_bins": (["bin"], log_x)}, + attrs={"description": "1-D Isotropic power spectrum", "units": "dB"}, + ) + + model_params = { + "beta_1": float(beta_1), + "beta_2": float(beta_2) if beta_2 is not None else None, + "c1": float(c1), + "c2": float(c2) if c2 is not None else None, + "scale_break": float(scale_break) if scale_break is not None else None, + } + + return ps_dataset, model_params + + +def prob_dist(data: np.ndarray, metadata: 
dict): + """ + Calculate the cumulative probability distribution for rain > threshold for dB field + + Args: + data (np.ndarray): 2D field of dB-transformed rain. + metadata (dict): pysteps metadata dictionary. + + Returns: + tuple: + - xarray Dataset containing the cumulative probability distribution and bin edges + - fraction of field with rain > threshold (float) + """ + + rain_mask = data > metadata["zerovalue"] + + # Compute cumulative probability distribution + min_db = metadata["threshold"] + max_db = 10 * np.log10(MAX_RAIN_RATE) + bin_edges = np.linspace(min_db, max_db, N_BINS) + + # Histogram of rain values + hist, _ = np.histogram(data[rain_mask], bins=bin_edges, density=True) + + # Compute cumulative distribution + cumulative_distr = np.cumsum(hist) / np.sum(hist) + + # Create an xarray Dataset to store both cumulative distribution and bin edges + cdf_dataset = xr.Dataset( + { + "cdf": (["bin"], cumulative_distr), + }, + coords={ + # bin_edges[:-1] to match the histogram bins + "cdf_bins": (["bin"], bin_edges[:-1]), + }, + attrs={ + "description": "Cumulative probability distribution of rain rates", + "units": "dB", + }, + ) + + return cdf_dataset + + +def compute_field_stats(data, geodata): + nonzero_mask = data >= geodata["threshold"] + nonzero_mean = np.mean(data[nonzero_mask]) if np.any(nonzero_mask) else np.nan + nonzero_stdev = np.std(data[nonzero_mask]) if np.any(nonzero_mask) else np.nan + nonzero_frac = np.sum(nonzero_mask) / data.size + mean_rain = np.nanmean(data) + stdev_rain = np.nanstd(data) + + rain_stats = { + "nonzero_mean": float(nonzero_mean) if nonzero_mean is not None else None, + "nonzero_stdev": float(nonzero_stdev) if nonzero_stdev is not None else None, + "nonzero_fraction": float(nonzero_frac) if nonzero_frac is not None else None, + "mean": float(mean_rain) if mean_rain is not None else None, + "stdev": float(stdev_rain) if stdev_rain is not None else None, + "transform": geodata["transform"], + "zerovalue": geodata["zerovalue"], + "threshold": geodata["threshold"], + } + return rain_stats + + +def is_stationary(phi1, phi2): + return abs(phi2) < 1 and (phi1 + phi2) < 1 and (phi2 - phi1) < 1 + + +def correlation_length( + lag1: float, lag2: float, dx=10, tol=1e-4, max_lag=1000 +) -> float: + """ + Calculate the correlation length in minutes assuming AR(2) process + Args: + lag1 (float): Lag 1 auto-correltion + lag2 (float): Lag 2 auto-correlation + dx (int, optional): time step between lag1 & 2 in minutes. Defaults to 10. + tol (float, optional): _description_. Defaults to 1e-4. + max_lag (int, optional): _description_. Defaults to 1000. + + Returns: + corl (float): Correlation length in minutes + np.nan on error + """ + if lag1 is None or lag2 is None: + return np.nan + + A = np.array([[1.0, lag1], [lag1, 1.0]]) + b = np.array([lag1, lag2]) + + try: + phi = np.linalg.solve(A, b) + except np.linalg.LinAlgError: + return np.nan + + phi1, phi2 = phi + if not is_stationary(phi1, phi2): + return np.nan + + rho_vals = [1.0, lag1, lag2] + for _ in range(3, max_lag): + next_rho = phi1 * rho_vals[-1] + phi2 * rho_vals[-2] + if abs(next_rho) < tol: + break + rho_vals.append(next_rho) + corl = np.trapz(rho_vals, dx=dx) + return float(corl) + + +def power_law_acor( + config: Dict[str, Any], T_ref: float +) -> tuple[np.ndarray, np.ndarray]: + """ + Compute lag-1 and lag-2 autocorrelations for each cascade level using a power-law model. + + Args: + config (dict): Configuration dictionary with 'pysteps.timestep' (in seconds) + and 'dynamic_scaling' parameters. 
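            As a worked sketch with hypothetical values: for a 600 s timestep,
            a reference correlation length of 180 min, central wavelengths of
            [256, 64] km and a space-time exponent ht = 0.7, the second level
            gets T_l = 180 * (64 / 256) ** 0.7 ≈ 68 min, hence
            lag1 = exp(-10 / 68) ≈ 0.86 and, with a = 1.0 and b = 2.0,
            lag2 = 1.0 * 0.86 ** 2.0 ≈ 0.74.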
+ T_ref (float): Reference correlation length T(t, L) at the largest scale (in minutes). + + Returns: + np.ndarray: Array of shape (n_levels, 2) with [lag1, lag2] for each level. + np.ndarray: Array of corelation lengths per level + """ + dt_seconds = config["pysteps"]["timestep"] + dt_mins = dt_seconds / 60.0 + + ds_config = config.get("dynamic_scaling", {}) + scales = ds_config["central_wave_lengths"] + ht = ds_config["space_time_exponent"] + a = ds_config["lag2_constants"] + b = ds_config["lag2_exponents"] + + L = scales[0] + T_levels = [T_ref * (l / L) ** ht for l in scales] + + lags = np.empty((len(scales), 2), dtype=np.float32) + for ia, T_l in enumerate(T_levels): + pl_lag1 = np.exp(-dt_mins / T_l) + pl_lag2 = a[ia] * (pl_lag1 ** b[ia]) + lags[ia, 0] = pl_lag1 + lags[ia, 1] = pl_lag2 + + levels = np.array(T_levels) + return lags, levels diff --git a/pysteps/param/shared_utils.py b/pysteps/param/shared_utils.py index 2d37199fa..f513c1754 100644 --- a/pysteps/param/shared_utils.py +++ b/pysteps/param/shared_utils.py @@ -1,131 +1,115 @@ +from collections.abc import Callable + import datetime +import copy import logging import numpy as np -from pysteps.cascade.decomposition import decomposition_fft, recompose_fft -from pysteps.timeseries import autoregression -from pysteps import extrapolation -from models.mongo_access import get_config, get_db -from models.steps_params import StochasticRainParameters -from models.stochastic_generator import gen_stoch_field, normalize_db_field - -def initialize_config(base_time_str, name): - try: - base_time = datetime.datetime.fromisoformat(base_time_str).replace(tzinfo=datetime.timezone.utc) - except ValueError: - raise ValueError(f"Invalid base time format: {base_time_str}") - - db = get_db() - config = get_config(db, name) - if config is None: - raise RuntimeError(f"Configuration not found for domain {name}") +import pandas as pd - return db, config, base_time +from statsmodels.tsa.api import SimpleExpSmoothing +from pysteps.cascade.decomposition import decomposition_fft, recompose_fft +from pysteps.timeseries.autoregression import ( + adjust_lag2_corrcoef2, + estimate_ar_params_yw, +) +from pysteps import extrapolation -def prepare_forecast_loop(db, config, base_time, name, product): - print(f"Running {product} for domain {name} at {base_time}") - # Placeholder for forecast generation logic - # This is where you'd insert the time loop, forecast logic, and output handling - pass +from utils.transformer import DBTransformer +from steps_params import StepsParameters +from stochastic_generator import gen_stoch_field, normalize_db_field +from rainfield_stats import correlation_length +from rainfield_stats import power_spectrum_1D +from cascade_utils import lagr_auto_cor -def update_field(cascades: list, optical_flow: np.ndarray, params: StochasticRainParameters, bp_filter: dict, config: dict) -> np.ndarray: +def update_field( + cascades: list, + oflow: np.ndarray, + params: StepsParameters, + bp_filter: dict, + config: dict, + dom: dict, +) -> np.ndarray: """ Update a rainfall field using the parametric STEPS algorithm. + Assumes that the cascades list has the correct number of valid cascades Args: - cascades (list): List of rainfall cascades for previous ar_order time steps. - optical_flow (np.ndarray): Optical flow field for Lagrangian updates. - params (StochasticRainParameters): Parameters for the update. 
+ cascades (list): List of 1 or 2 cascades for initial conditions + oflow(np.ndarray): Optical flow array + params (StepsParameters): Parameters for the update. bp_filter: Bandpass filter dictionary returned by pysteps.cascade.bandpass_filters.filter_gaussian - config: The configuration dictionary + config: The configuration dictionary + dom: the domain dictionary Returns: - np.ndarray: Updated rainfall field in decibels (dB) of rain intensity + np.ndarray: Updated rainfall field in decibels (dB) of rain intensity """ - ar_order = config['pysteps']['ar_order'] - n_levels = config['pysteps']['n_cascade_levels'] - n_rows = config['domain']['n_rows'] - n_cols = config['domain']['n_cols'] - - # Ensure that we have valid input parameters - number_none_states = sum(1 for v in cascades if v is None) - if (number_none_states != 0) or (optical_flow is None) or (params is None): - logging.debug( - "Missing cascade values, skipping forecast.") - return None - - # Calculate the AR phi parameters, check if there any cascade parameters - if params.cascade_lag1 is None: - logging.debug( - "No valid cascade lag1 values found in the parameters. Skipping forecast.") - return None - if ar_order == 2 and params.cascade_lag2 is None: - logging.debug( - "No valid cascade lag2 values found in the parameters. Skipping forecast.") - return None - - # Check if the lag 1 and lag 2 are all valid - number_none_lag1 = sum(1 for v in params.cascade_lag1 if np.isnan(v)) - number_none_lag2 = 0 - if ar_order == 2: - number_none_lag2 = sum(1 for v in params.cascade_lag2 if np.isnan(v)) - - # Fill the lag1 and lag2 with the default parameters - if number_none_lag1 != 0 or number_none_lag2 != 0: - params.corl_zero = config["dynamic_scaling"]["cor_len_pvals"][1] - params.calc_acor(config) + ar_order = config["ar_order"] + n_levels = config["n_cascade_levels"] + n_rows = dom["n_rows"] + n_cols = dom["n_cols"] + + scale_break_km = config["scale_break"] + kmperpixel = config["kmperpixel"] + + rain_threshold = config["precip_threshold"] + db_threshold = 10 * np.log10(rain_threshold) + transformer = DBTransformer(rain_threshold) + zerovalue = transformer.zerovalue + + # Set up the AR(2) parameters phi = np.zeros((n_levels, ar_order + 1)) for ilev in range(n_levels): - gamma_1 = params.cascade_lag1[ilev] + gamma_1 = params.lag_1[ilev] + gamma_2 = params.lag_2[ilev] if ar_order == 2: - gamma_2 = autoregression.adjust_lag2_corrcoef2( - gamma_1, params.cascade_lag2[ilev]) - phi[ilev] = autoregression.estimate_ar_params_yw( - [gamma_1, gamma_2]) + gamma_2 = adjust_lag2_corrcoef2(gamma_1, gamma_2) + phi[ilev] = estimate_ar_params_yw([gamma_1, gamma_2]) else: - phi[ilev] = autoregression.estimate_ar_params_yw( - [gamma_1]) + phi[ilev] = estimate_ar_params_yw([gamma_1]) # Generate the noise field and cascade - noise_field = gen_stoch_field(params, n_cols, n_rows) - max_dbr = 10*np.log10(150) - min_dbr = 10*np.log10(0.05) - noise_field = np.clip(noise_field, min_dbr, max_dbr) + noise_field = gen_stoch_field( + params, n_cols, n_rows, kmperpixel, scale_break_km, db_threshold + ) noise_cascade = decomposition_fft( - noise_field, bp_filter, compute_stats=True, normalize=True) + noise_field, bp_filter, compute_stats=True, normalize=True + ) # Update the cascade extrapolation_method = extrapolation.get_method("semilagrangian") lag_0 = np.zeros((n_levels, n_rows, n_cols)) - if ar_order == 1: - lag_1 = cascades[0]["cascade_levels"] - else: - lag_2 = cascades[0]["cascade_levels"] - lag_1 = cascades[1]["cascade_levels"] - # Loop over cascade levels 
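A minimal sketch, with hypothetical lag values, of the per-level AR(2)
coefficient estimation that update_field performs before mixing the advected
cascades with the noise cascade:

    from pysteps.timeseries.autoregression import (
        adjust_lag2_corrcoef2,
        estimate_ar_params_yw,
    )

    gamma_1, gamma_2 = 0.90, 0.75            # hypothetical autocorrelations
    gamma_2 = adjust_lag2_corrcoef2(gamma_1, gamma_2)  # keep the AR(2) stationary
    phi = estimate_ar_params_yw([gamma_1, gamma_2])
    # phi has length ar_order + 1: two AR coefficients plus the innovation
    # term that weights the stochastic noise cascade level.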
- for ilev in range(n_levels): - # Set the outside pixels to zero - adv_lag1 = extrapolation_method( - lag_1[ilev], optical_flow, 1, outval=0)[0] - if ar_order == 1: - lag_0[ilev] = phi[ilev, 0] * adv_lag1 + \ - phi[ilev, 1] * noise_cascade["cascade_levels"][ilev] + if ar_order == 2: + lag_1 = copy.deepcopy(cascades[0]["cascade_levels"]) + lag_2 = copy.deepcopy(cascades[1]["cascade_levels"]) - else: - # Set the outside pixels to zero - adv_lag2 = extrapolation_method( - lag_2[ilev], optical_flow, 2, outval=0)[1] - lag_0[ilev] = phi[ilev, 0] * adv_lag1 + phi[ilev, 1] * \ - adv_lag2 + phi[ilev, 2] * noise_cascade["cascade_levels"][ilev] - - # Make sure we have mean = 0, stdev = 1 - lev_mean = np.mean(lag_0) - lev_stdev = np.std(lag_0) - if lev_stdev > 1e-1: - lag_0 = (lag_0 - lev_mean)/lev_stdev + for ilev in range(n_levels): + adv_lag2 = extrapolation_method(lag_2[ilev], oflow, 2, outval=0)[1] + adv_lag1 = extrapolation_method(lag_1[ilev], oflow, 1, outval=0)[0] + lag_0[ilev] = ( + phi[ilev, 0] * adv_lag1 + + phi[ilev, 1] * adv_lag2 + + phi[ilev, 2] * noise_cascade["cascade_levels"][ilev] + ) + + else: + lag_1 = copy.deepcopy(cascades[0]["cascade_levels"]) + for ilev in range(n_levels): + adv_lag1 = extrapolation_method(lag_1[ilev], oflow, 1, outval=0)[0] + lag_0[ilev] = ( + phi[ilev, 0] * adv_lag1 + + phi[ilev, 1] * noise_cascade["cascade_levels"][ilev] + ) + + # Make sure we have mean = 0, stdev = 1 + lev_mean = np.mean(lag_0) + lev_stdev = np.std(lag_0) + if lev_stdev > 1e-1: + lag_0 = (lag_0 - lev_mean) / lev_stdev # Recompose the cascade into a single field updated_cascade = {} @@ -137,45 +121,465 @@ def update_field(cascades: list, optical_flow: np.ndarray, params: StochasticRai # Use the noise cascade level stds updated_cascade["means"] = noise_cascade["means"].copy() updated_cascade["stds"] = noise_cascade["stds"].copy() - - # Reduce the bias in the last cascade level due to the gradient in rain / no rain - high_freq_bias = 0.80 - updated_cascade["stds"][-1] *= high_freq_bias gen_field = recompose_fft(updated_cascade) # Normalise the field to have the expected conditional mean and variance - norm_field = normalize_db_field(gen_field, params) + norm_field = normalize_db_field(gen_field, params, db_threshold, zerovalue) return norm_field -def zero_state(config): - n_cascade_levels = config['pysteps']['n_cascade_levels'] - n_rows = config['domain']['n_rows'] - n_cols = config['domain']['n_cols'] + +def zero_state(config, domain): + n_cascade_levels = config["n_cascade_levels"] + n_rows = domain["n_rows"] + n_cols = domain["n_cols"] metadata_dict = { - "transform": config['pysteps']['transform'], - "threshold": config['pysteps']['threshold'], - "zerovalue": config['pysteps']['zerovalue'], + "transform": None, + "threshold": None, + "zerovalue": None, "mean": float(0), "std_dev": float(0), - "wetted_area_ratio": float(0) + "wetted_area_ratio": float(0), } cascade_dict = { "cascade_levels": np.zeros((n_cascade_levels, n_rows, n_cols)), "means": np.zeros(n_cascade_levels), "stds": np.zeros(n_cascade_levels), - "domain": 'spatial', + "domain": "spatial", "normalized": True, } oflow = np.zeros((2, n_rows, n_cols)) - state = { - "cascade": cascade_dict, - "optical_flow": oflow, - "metadata": metadata_dict - } + state = {"cascade": cascade_dict, "optical_flow": oflow, "metadata": metadata_dict} return state def is_zero_state(state, tol=1e-6): return abs(state["metadata"]["mean"]) < tol + +# climatology of the parameters for radar QPE 9000 sets of parameters in Auckland +# 95% 50% 5% +# 
nonzero_mean_db 6.883147 4.590082 2.815397 +# nonzero_stdev_db 3.793680 2.489131 1.298552 +# rain_fraction 0.447717 0.048889 0.008789 +# beta_1 -0.452957 -1.681647 -2.726216 +# beta_2 -2.322891 -3.251342 -4.009131 +# corl_zero 1074.976508 188.058276 23.489147 + + +def qc_params(ens_df, config): + """ + Apply QC to the 'param' column in the ensemble DataFrame. + The DataFrame is assumed to have 'valid_time' as the index. + + Smooth corl_zero using exponential smoothing and recompute cascade autocorrelations. + Clamp smoothed parameters to climatological bounds. + Returns a deep-copied DataFrame with corrected parameters. + """ + var_list = [ + "nonzero_mean_db", + "nonzero_stdev_db", + "rain_fraction", + "beta_1", + "beta_2", + ] + var_lower = [2.81, 1.30, 0.0, -2.73, -4.01] + var_upper = [9.50, 5.00, 1.0, -2.05, -2.32] + + qc_df = ens_df.copy(deep=True) + qc_dict = {} + + # Smooth each variable and clamp to bounds + for iv, var in enumerate(var_list): + x_list = [ + ( + np.nan + if qc_df.at[idx, "param"].get(var) is None + else qc_df.at[idx, "param"].get(var) + ) + for idx in qc_df.index + ] + + model = SimpleExpSmoothing(x_list, initialization_method="estimated").fit( + smoothing_level=0.10, optimized=False + ) + qc_dict[var] = np.clip(model.fittedvalues, var_lower[iv], var_upper[iv]) + + # Extract correlation length thresholds from config + corl_pvals = config["dynamic_scaling"]["cor_len_pvals"] + corl_min = min(corl_pvals) + corl_max = max(corl_pvals) + corl_def = corl_pvals[1] # median + + # Prepare and smooth corl_zero + corl_list = [] + for idx in qc_df.index: + corl = qc_df.at[idx, "param"].get("corl_zero", corl_def) + corl = corl_def if corl is None else max(corl_min, min(corl, corl_max)) + corl_list.append(corl) + + model = SimpleExpSmoothing(corl_list, initialization_method="estimated").fit( + smoothing_level=0.1, optimized=False + ) + qc_dict["corl_zero"] = model.fittedvalues + + # Assign smoothed parameters and compute lags + for i, idx in enumerate(qc_df.index): + param = copy.deepcopy(qc_df.at[idx, "param"]) + + # Ensure valid spectral slope order + if qc_dict["beta_2"][i] > qc_dict["beta_1"][i]: + qc_dict["beta_2"][i] = qc_dict["beta_1"][i] + + # Assign smoothed & clamped values + for var in var_list: + setattr(param, var, qc_dict[var][i]) + param.corl_zero = qc_dict["corl_zero"][i] + + # Compute lag-1 and lag-2 for this correlation length + lags, _ = calc_auto_corls(config, param.corl_zero) + param.lag_1 = list(lags[:, 0]) + param.lag_2 = list(lags[:, 1]) + + # Save updated object + qc_df.at[idx, "param"] = param + + return qc_df + + +def blend_param(qpe_params, nwp_params, param_names, weight): + for pname in param_names: + + qval = getattr(qpe_params, pname, None) + nval = getattr(nwp_params, pname, None) + if isinstance(qval, (int, float)) and isinstance(nval, (int, float)): + setattr(nwp_params, pname, weight * qval + (1 - weight) * nval) + elif ( + isinstance(qval, list) and isinstance(nval, list) and len(qval) == len(nval) + ): + setattr( + nwp_params, + pname, + [weight * q + (1 - weight) * n for q, n in zip(qval, nval)], + ) + return nwp_params + + +def blend_parameters( + config: dict[str, object], + blend_base_time: datetime.datetime, + nwp_param_df: pd.DataFrame, + rad_param: StepsParameters, + weight_fn: Callable[[float], float] | None = None, +) -> pd.DataFrame: + """ + Function to blend the radar and NWP parameters + + Args: + config (dict): Configuration dictionary + blend_base_time (datetime.datetime): Time of the radar parameter set + nwp_param_df 
(pd.DataFrame): Dataframe of valid_time and parameters, + with valid_time as index and of type datetime.datetime + rad_param (StochasticRainParameters): Parameter object with radar parameters + weight_fn (Optional[Callable[[float], float]], optional): _description_. Defaults to None. + + Returns: + pd.DataFrame: _description_ + """ + + def default_weight_fn(lag_sec: float) -> float: + return np.exp(-((lag_sec / 10800) ** 2)) # 3h Gaussian + + if weight_fn is None: + weight_fn = default_weight_fn + + blended_param_names = [ + "nonzero_mean_db", + "nonzero_stdev_db", + "rain_fraction", + "beta_1", + "beta_2", + "corl_zero", + ] + blended_df = copy.deepcopy(nwp_param_df) + for vtime in blended_df.index: + lag_sec = (vtime - blend_base_time).total_seconds() + weight = weight_fn(lag_sec) + + # Select the parameter object for this vtime and blend + original = blended_df.loc[vtime, "param"] + clean_original = copy.deepcopy(original) + updated = blend_param(rad_param, clean_original, blended_param_names, weight) + + # Compute lag-1 and lag-2 for this correlation length + lags, _ = calc_auto_corls(config, updated.corl_zero) + updated.lag_1 = list(lags[:, 0]) + updated.lag_2 = list(lags[:, 1]) + + blended_df.loc[vtime, "param"] = updated + + return blended_df + + +def fill_param_gaps( + ens_df: pd.DataFrame, forecast_times: list[datetime.datetime] +) -> pd.DataFrame: + """ + Fill gaps in the time series of parameters with the most recent *original* observation + if the gap is smaller than a threshold. + + Assumes that all the parameters have the same domain, product, base_time, ensemble. + + Args: + ens_df (pd.DataFrame): DataFrame with columns 'valid_time' and 'param'. + forecast_times (list): List of datetime.datetime in UTC. + + Returns: + pd.DataFrame: DataFrame with gaps filled. + """ + max_gap = datetime.timedelta(hours=6) + + ens_df = ens_df.copy() + ens_df["valid_time"] = pd.to_datetime(ens_df["valid_time"], utc=True) + ens_df = ens_df.sort_values("valid_time").reset_index(drop=True) + + filled_map = dict(zip(ens_df["valid_time"], ens_df["param"])) + original_times = set(ens_df["valid_time"]) + + # Extract default metadata + first_param = ens_df.iloc[0].at["param"] + def_metadata_base = first_param.metadata.copy() + + for vtime in forecast_times: + if vtime in filled_map: + continue + + metadata = def_metadata_base.copy() + metadata["valid_time"] = vtime + + # Find the nearest valid time + if original_times: + nearest_time = min(original_times, key=lambda t: abs(t - vtime)) + gap = abs(nearest_time - vtime) + + if gap <= max_gap: + logging.debug( + f"Filling {vtime} with params from nearest time {nearest_time} (gap = {gap})" + ) + def_param = copy.deepcopy(filled_map[nearest_time]) + def_param.metadata = metadata + def_param.rain_fraction = 0 + else: + logging.debug( + f"Nearest gap too large to fill for {vtime}, using default" + ) + def_param = StepsParameters(metadata=metadata) + else: + logging.debug(f"No valid parameter found near {vtime}, using default") + def_param = StepsParameters(metadata=metadata) + + filled_map[vtime] = def_param + + records = [{"valid_time": t, "param": p} for t, p in sorted(filled_map.items())] + return pd.DataFrame(records) + + +def calc_auto_corls(config: dict, T_ref: float) -> tuple[np.ndarray, np.ndarray]: + """ + Compute lag-1 and lag-2 autocorrelations for each cascade level using a power-law model. + + Args: + config (dict): Configuration dictionary with 'pysteps.timestep' (in seconds) + and 'dynamic_scaling' parameters. 
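            For reference, the default weight used by blend_parameters above is
            a Gaussian in forecast lag with a 3 h scale, so roughly 0.90 at
            1 h, 0.37 at 3 h and 0.02 at 6 h. A sketch of an alternative
            weight_fn that could be passed in (purely illustrative, with a
            hypothetical 6 h cut-off):

                import numpy as np

                def linear_weight_fn(lag_sec: float, zero_at_sec: float = 21600.0) -> float:
                    # Linear decay from 1 at the blend base time to 0 at the cut-off.
                    return float(np.clip(1.0 - lag_sec / zero_at_sec, 0.0, 1.0))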
+ T_ref (float): Reference correlation length T(t, L) at the largest scale (in minutes). + + Returns: + np.ndarray: Array of shape (n_levels, 2) with [lag1, lag2] for each level. + np.ndarray: Array of corelation lengths per level + """ + dt_seconds = config["timestep"] + dt_mins = dt_seconds / 60.0 + + ds_config = config.get("dynamic_scaling", {}) + scales = ds_config["central_wave_lengths"] + ht = ds_config["space_time_exponent"] + a = ds_config["lag2_constants"] + b = ds_config["lag2_exponents"] + + L = scales[0] + T_levels = [T_ref * (l / L) ** ht for l in scales] + + lags = np.empty((len(scales), 2), dtype=np.float32) + for ia, T_l in enumerate(T_levels): + pl_lag1 = np.exp(-dt_mins / T_l) + pl_lag2 = a[ia] * (pl_lag1 ** b[ia]) + lags[ia, 0] = pl_lag1 + lags[ia, 1] = pl_lag2 + + levels = np.array(T_levels) + return lags, levels + + +def fit_auto_cors( + clen: float, + alpha: float, + d_mins: int, + *, + allow_negative: bool = False, + return_diagnostics: bool = False, +): + """ + Find lag1, lag2 (with lag2 = lag1**alpha) such that + correlation_length(lag1, lag2, d_mins) ~= clen. + + Parameters + ---------- + clen : float + Target correlation length (minutes), must be > 0. + alpha : float + Exponent linking lag2 and lag1 via lag2 = lag1**alpha. + d_mins : int + Time step between lag1 and lag2 (minutes). + allow_negative : bool, optional + If True, search lag1 in (-1, 1). Otherwise restrict to (0, 1). + return_diagnostics : bool, optional + If True, also return achieved correlation length and absolute error. + + Returns + ------- + lag1 : float + lag2 : float + (achieved_clen, abs_error) : tuple[float, float], only if return_diagnostics=True + """ + if not np.isfinite(clen) or clen <= 0: + raise ValueError("clen must be a positive, finite number.") + if not np.isfinite(alpha): + raise ValueError("alpha must be finite.") + if not np.isfinite(d_mins) or d_mins <= 0: + raise ValueError("d_mins must be a positive, finite number.") + + # Stability / search bounds for lag1 + eps = 1e-3 + lo, hi = (-0.999999, 0.999999) if allow_negative else (eps, 0.999999) + + # Objective: squared error on correlation length + def _obj(l1: float) -> float: + # Quick rejection of out-of-bounds + if not (lo < l1 < hi): + return np.inf + l2 = l1**alpha + + # Keep |lag2| < 1 as well to stay in a stable region + if not (abs(l2) < 1.0): + return np.inf + + # Make sure that we have a valid lag1, lag2 combination + l2 = adjust_lag2_corrcoef2(l1, l2) + + c = correlation_length(l1, l2, d_mins) + if not np.isfinite(c): + return np.inf + return (c - clen) ** 2 + + # Try SciPy first + lag1 = None + try: + from scipy.optimize import minimize_scalar # type: ignore + + res = minimize_scalar( + _obj, bounds=(lo, hi), method="bounded", options={"xatol": 1e-10} + ) + lag1 = float(res.x) + except Exception: + # Pure NumPy golden-section fallback + phi = (1.0 + np.sqrt(5.0)) / 2.0 + a, b = lo, hi + c = b - (b - a) / phi + d = a + (b - a) / phi + fc = _obj(c) + fd = _obj(d) + # Max ~100 iterations gives ~1e-8 bracket typically + for _ in range(100): + if fc < fd: + b, d, fd = d, c, fc + c = b - (b - a) / phi + fc = _obj(c) + else: + a, c, fc = c, d, fd + d = a + (b - a) / phi + fd = _obj(d) + if (b - a) < 1e-10: + break + lag1 = float((a + b) / 2.0) + + lag2 = float(lag1**alpha) + achieved = float(correlation_length(lag1, lag2, d_mins)) + err = abs(achieved - clen) + + if return_diagnostics: + return lag1, lag2, (achieved, err) + return lag1, lag2 + + +def calc_corls(scales, czero, ht): + # Power law function for correlation 
length + corls = [czero] + lzero = scales[0] + for scale in scales[1:]: + corl = czero * (scale / lzero) ** ht + corls.append(corl) + return corls + + +def calculate_parameters( + db_field: np.ndarray, + cascades: dict, + oflow: np.ndarray, + scale_break: float, + zero_value: float, + dt: int, +): + p_dict = {} + + # Probability distribution moments + nonzero_mask = db_field > zero_value + p_dict["nonzero_mean_db"] = ( + np.mean(db_field[nonzero_mask]) if np.any(nonzero_mask) else np.nan + ) + p_dict["nonzero_stdev_db"] = ( + np.std(db_field[nonzero_mask]) if np.any(nonzero_mask) else np.nan + ) + p_dict["rain_fraction"] = np.sum(nonzero_mask) / db_field.size + + # Power spectrum slopes + _, ps_model = power_spectrum_1D(db_field, scale_break) + if ps_model: + p_dict["beta_1"] = ps_model.get("beta_1", -2.05) + p_dict["beta_2"] = ps_model.get("beta_2", -3.2) + else: + p_dict["beta_1"] = -2.05 + p_dict["beta_2"] = -3.2 + + # Stack the (k,m,n) arrays in order t-2, t-1, t0 to get (t,k,m,n) array + data = [] + for ia in range(3): + data.append(cascades[ia]["cascade_levels"]) + data = np.stack(data) + a_corls = lagr_auto_cor(data, oflow) + n_levels = a_corls.shape[0] + lag_1 = [] + lag_2 = [] + clens = [] + for ilag in range(n_levels): + r1 = float(a_corls[ilag][0]) + r2 = float(a_corls[ilag][1]) + clen = correlation_length(r1, r2, dt) + lag_1.append(r1) + lag_2.append(r2) + clens.append(clen) + + p_dict["lag_1"] = lag_1 + p_dict["lag_2"] = lag_2 + p_dict["corl_zero"] = clens[0] + + return StepsParameters.from_dict(p_dict) diff --git a/pysteps/param/steps_params.py b/pysteps/param/steps_params.py index 16763008d..aaa9a56c8 100644 --- a/pysteps/param/steps_params.py +++ b/pysteps/param/steps_params.py @@ -1,555 +1,124 @@ -# Contains: StochasticRainParameters (dataclass, from_dict, to_dict), compute_field_parameters, compute_field_stats -""" - Functions to implement the parametric version of STEPS -""" -from typing import Optional, Tuple, Dict, Union, List, Callable +from dataclasses import dataclass, field import datetime -import copy -import logging -from typing import Optional, List, Dict, Any -from dataclasses import dataclass -import xarray as xr -from scipy.optimize import curve_fit -import numpy as np -import pandas as pd -MAX_RAIN_RATE = 250 -N_BINS = 200 +# 95% 50% 5% +# nonzero_mean_db 6.883147 4.590082 2.815397 +# nonzero_stdev_db 3.793680 2.489131 1.298552 +# rain_fraction 0.447717 0.048889 0.008789 +# beta_1 -0.452957 -1.681647 -2.726216 +# beta_2 -2.322891 -3.251342 -4.009131 +# corl_zero 1074.976508 188.058276 23.489147 -@dataclass -class StochasticRainParameters: - transform: Optional[str] = None - zerovalue: Optional[float] = None - threshold: Optional[float] = None - kmperpixel: Optional[float] = None - - mean_db: Optional[float] = None - stdev_db: Optional[float] = None - nonzero_mean_rain: Optional[float] = None - nonzero_stdev_rain: Optional[float] = None - mean_rain: Optional[float] = None - stdev_rain: Optional[float] = None - - psd: Optional[List[float]] = None - psd_bins: Optional[List[float]] = None - c1: Optional[float] = None - c2: Optional[float] = None - scale_break: Optional[float] = None - cdf: Optional[List[float]] = None - cdf_bins: Optional[List[float]] = None - cascade_stds: Optional[List[float]] = None - cascade_means: Optional[List[float]] = None - cascade_lag1: Optional[List[float]] = None - cascade_lag2: Optional[List[float]] = None - cascade_corl: Optional[List[float]] = None - product: Optional[str] = None - valid_time: Optional[datetime.datetime] = None - 
base_time: Optional[datetime.datetime] = None - ensemble: Optional[int] = None - field_id: Optional[str] = None +@dataclass +class StepsParameters: + metadata: dict - # Defaulted parameters - nonzero_mean_db: float = 2.3 - nonzero_stdev_db: float = 5.6 + # STEPS parameters with defaults for light rain + nonzero_mean_db: float = 2.81 + nonzero_stdev_db: float = 1.3 rain_fraction: float = 0 - beta_1: float = -2.06 - beta_2: float = 3.2 - corl_zero: float = 260 - - def get(self, key: str, default: Any = None) -> Any: - """Mimic dict.get() for selected attributes.""" - return getattr(self, key, default) - - def calc_corl(self): - """Populate the correlation lengths using lag1 and lag2 values.""" - if self.cascade_lag1 is None or self.cascade_lag2 is None: - return - - n_levels = len(self.cascade_lag1) - if len(self.cascade_corl) != n_levels: - self.cascade_corl = [np.nan] * n_levels - - for ilev in range(n_levels): - lag1 = self.cascade_lag1[ilev] - lag2 = self.cascade_lag2[ilev] - self.cascade_corl[ilev] = correlation_length(lag1, lag2) - - # Convenience for blending with radar - self.corl_zero = self.cascade_corl[0] + beta_1: float = -2.05 + beta_2: float = -3.2 + corl_zero: float = 180 + + # Auto-correlation lists + lag_1: list[float] = field(default_factory=list) + lag_2: list[float] = field(default_factory=list) + + # Required metadata keys + _required_metadata_keys = { + "domain", + "product", + "valid_time", + "base_time", + "ensemble", + } - def calc_acor(self, config) -> None: - T_ref = self.corl_zero - if T_ref is None or np.isnan(T_ref): - T_ref = config["dynamic_scaling"]["cor_len_pvals"][1] - - acor, corl = power_law_acor(config, T_ref) - self.cascade_corl = [float(x) for x in corl] - self.cascade_lag1 = [float(x) for x in acor[:, 0]] - self.cascade_lag2 = [float(x) for x in acor[:, 1]] + def get(self, key: str, default=None): + """Mimic dict.get(). 
Check metadata first, then top-level attributes.""" + if key in self.metadata: + value = self.metadata.get(key) + if ( + value is None + and key in self._required_metadata_keys + and default is None + ): + raise KeyError(f"Required metadata key '{key}' is missing or None.") + return value if value is not None else default + else: + return getattr(self, key, default) + + def set_metadata(self, key: str, value): + """Set a metadata key/value pair and validate if required.""" + self.metadata[key] = value + if key in self._required_metadata_keys and value is None: + raise ValueError(f"Required metadata key '{key}' cannot be None.") + + def validate(self): + """Raise ValueError if any required field is missing or None.""" + for key in self._required_metadata_keys: + if key not in self.metadata or self.metadata[key] is None: + raise ValueError(f"Missing required metadata field: '{key}'") @classmethod + def from_dict(cls, data: dict): + """Create a StepsParameters object from a dictionary.""" + + def ensure_utc(dt): + if dt is None: + return None + if isinstance(dt, str): + dt = datetime.datetime.fromisoformat(dt) + if dt.tzinfo is None: + return dt.replace(tzinfo=datetime.timezone.utc) + return dt.astimezone(datetime.timezone.utc) - def from_dict(cls, data: Dict[str, Any]) -> "StochasticRainParameters": - dbr = data.get("dbr_stats", {}) - rain = data.get("rain_stats", {}) - pspec = data.get("power_spectrum", {}) - model = pspec.get("model", {}) if pspec else {} - cdf_data = data.get("cdf", {}) - cascade = data.get("cascade", {}) meta = data.get("metadata", {}) + if meta is not None: + metadata = { + "domain": meta.get("domain"), + "product": meta.get("product"), + "valid_time": ensure_utc(meta.get("valid_time")), + "base_time": ensure_utc(meta.get("base_time")), + "ensemble": meta.get("ensemble"), + } + else: + metadata = {} return cls( - product=meta.get("product"), - valid_time=meta.get("valid_time"), - base_time=meta.get("base_time"), - ensemble=meta.get("ensemble"), - field_id=meta.get("field_id"), - transform=dbr.get("transform"), - zerovalue=dbr.get("zerovalue"), - threshold=dbr.get("threshold"), - kmperpixel=meta.get("kmperpixel"), - - nonzero_mean_db=dbr.get("nonzero_mean"), - nonzero_stdev_db=dbr.get("nonzero_stdev"), - rain_fraction=dbr.get("nonzero_fraction"), - mean_db=dbr.get("mean"), - stdev_db=dbr.get("stdev"), - nonzero_mean_rain=rain.get("nonzero_mean"), - nonzero_stdev_rain=rain.get("nonzero_stdev"), - mean_rain=rain.get("mean"), - stdev_rain=rain.get("stdev"), - - psd=pspec.get("psd", []), - psd_bins=pspec.get("psd_bins", []), - - beta_1=model.get("beta_1"), - beta_2=model.get("beta_2"), - c1=model.get("c1"), - c2=model.get("c2"), - scale_break=model.get("scale_break"), - - cdf=cdf_data.get("cdf", []), - cdf_bins=cdf_data.get("cdf_bins", []), - - corl_zero=cascade.get("corl_zero"), - cascade_stds=cascade.get("stds"), - cascade_means=cascade.get("means"), - cascade_lag1=cascade.get("lag1"), - cascade_lag2=cascade.get("lag2"), - cascade_corl=[np.nan] * len(cascade.get("lag1", [])) + metadata=metadata, + nonzero_mean_db=data.get("nonzero_mean_db", 2.81), + nonzero_stdev_db=data.get("nonzero_stdev_db", 1.3), + rain_fraction=data.get("rain_fraction", 0), + beta_1=data.get("beta_1", -2.05), + beta_2=data.get("beta_2", -3.2), + corl_zero=data.get("corl_zero", 180), + lag_1=data.get("lag_1", []), + lag_2=data.get("lag_2", []), ) - def to_dict(self) -> Dict[str, Any]: - return { - "dbr_stats": { - "transform": self.transform, - "zerovalue": self.zerovalue, - "threshold": 
self.threshold, - "nonzero_mean": self.nonzero_mean_db, - "nonzero_stdev": self.nonzero_stdev_db, - "nonzero_fraction": self.rain_fraction, - "mean": self.mean_db, - "stdev": self.stdev_db, - }, - "rain_stats": { - "nonzero_mean": self.nonzero_mean_rain, - "nonzero_stdev": self.nonzero_stdev_rain, - "nonzero_fraction": self.rain_fraction, # assume same as dbr_stats - "mean": self.mean_rain, - "stdev": self.stdev_rain, - "transform": None, - "zerovalue": 0, - "threshold": 0.1, - }, - "power_spectrum": { - "psd": self.psd, - "psd_bins": self.psd_bins, - "model": { - "beta_1": self.beta_1, - "beta_2": self.beta_2, - "c1": self.c1, - "c2": self.c2, - "scale_break": self.scale_break, - } if any(x is not None for x in [self.beta_1, self.beta_2, self.c1, self.c2, self.scale_break]) else None - }, - "cdf": { - "cdf": self.cdf, - "cdf_bins": self.cdf_bins, - }, - "cascade": { - "corl_zero":self.corl_zero, - "stds": self.cascade_stds, - "means": self.cascade_means, - "lag1": self.cascade_lag1, - "lag2": self.cascade_lag2, - "corl": self.cascade_corl, - } if self.cascade_stds is not None else None, - "metadata": { - "kmperpixel": self.kmperpixel, - "product": self.product, - "valid_time": self.valid_time, - "base_time": self.base_time, - "ensemble": self.ensemble, - "field_id": self.field_id, + def to_dict(self): + """Convert the object into a dictionary for MongoDB or JSON.""" + if self.metadata is not None: + metadata = { + "domain": self.metadata["domain"], + "product": self.metadata["product"], + "valid_time": self.metadata["valid_time"], + "base_time": self.metadata["base_time"], + "ensemble": self.metadata["ensemble"], } - } - - -def compute_field_parameters(db_data: np.ndarray, db_metadata: dict, scale_break_km: Optional[float] = None): - """ - Compute STEPS parameters for the dB transformed rainfall field - - Args: - db_data (np.ndarray): 2D field of dB-transformed rain. - db_metadata (dict): pysteps metadata dictionary. - - Returns: - dict: Dictionary containing STEPS parameters. - """ - - # Compute power spectrum model - if scale_break_km is not None: - scalebreak = scale_break_km * 1000.0 / db_metadata["xpixelsize"] - else: - scalebreak = None - ps_dataset, ps_model = power_spectrum_1D(db_data, scalebreak) - power_spectrum = { - "psd": ps_dataset.psd.values.tolist(), - "psd_bins": ps_dataset.psd_bins.values.tolist(), - "model": ps_model - } - - # Compute cumulative probability distribution - cdf_dataset = prob_dist(db_data, db_metadata) - cdf = { - "cdf": cdf_dataset.cdf.values.tolist(), - "cdf_bins": cdf_dataset.cdf_bins.values.tolist(), - } - - # Store parameters in a dictionary - steps_params = { - "timestamp": datetime.datetime.now(datetime.timezone.utc), - "power_spectrum": power_spectrum, - "cdf": cdf - } - return steps_params - - -def power_spectrum_1D(field: np.ndarray, scale_break: Optional[float] = None - ) -> Tuple[Optional[xr.Dataset], Optional[Dict[str, float]]]: - """ - Calculate the 1D isotropic power spectrum and fit a power law model. - - Args: - field (np.ndarray): 2D input field in [rows, columns] order. - scale_break (float, optional): Scale break in pixel units. If None, fit single line. - - Returns: - ps_dataset (xarray.Dataset): 1D isotropic power spectrum in dB. 
- model_params (dict): Dictionary with model parameters: beta_1, beta_2, c1, c2, scale_break - """ - min_stdev = 0.1 - mean = np.nanmean(field) - stdev = np.nanstd(field) - if stdev < min_stdev: - return None, None - - norm_field = (field - mean) / stdev - np.nan_to_num(norm_field, copy=False) - - field_fft = np.fft.rfft2(norm_field) - power_spectrum = np.abs(field_fft) ** 2 - - freq_x = np.fft.fftfreq(field.shape[1]) - freq_y = np.fft.fftfreq(field.shape[0]) - freq_r = np.sqrt(freq_x[:, None]**2 + freq_y[None, :]**2) - freq_r = freq_r[: field.shape[0] // 2, : field.shape[1] // 2] - power_spectrum = power_spectrum[: field.shape[0] // - 2, : field.shape[1] // 2] - - n_bins = power_spectrum.shape[0] - bins = np.logspace(np.log10(freq_r.min() + 1 / n_bins), - np.log10(freq_r.max()), num=n_bins) - bin_centers = (bins[:-1] + bins[1:]) / 2 - power_1d = np.zeros(len(bin_centers)) - - for i in range(len(bins) - 1): - mask = (freq_r >= bins[i]) & (freq_r < bins[i + 1]) - power_1d[i] = np.nanmean( - power_spectrum[mask]) if np.any(mask) else np.nan - - valid = (bin_centers > 0) & (~np.isnan(power_1d)) - bin_centers = bin_centers[valid] - power_1d = power_1d[valid] - - if len(bin_centers) == 0: - return None, None - - log_x = 10*np.log10(bin_centers) - log_y = 10*np.log10(power_1d) - - start_idx = 2 - end_idx = np.searchsorted(log_x, -4.0) - - model_params = {} - - if scale_break is None: - def str_line(X, m, c): return m * X + c - popt, _ = curve_fit( - str_line, log_x[start_idx:end_idx], log_y[start_idx:end_idx]) - beta_1, c1 = popt - beta_2 = None - c2 = None - sb_log = None - else: - sb_freq = 1.0 / scale_break - sb_log = 10*np.log10(sb_freq) - - def piecewise_linear(x, m1, m2, c1): - c2 = (m1 - m2) * sb_log + c1 - return np.where(x <= sb_log, m1 * x + c1, m2 * x + c2) - - popt, _ = curve_fit( - piecewise_linear, log_x[start_idx:end_idx], log_y[start_idx:end_idx]) - beta_1, beta_2, c1 = popt - c2 = (beta_1 - beta_2) * sb_log + c1 - - ps_dataset = xr.Dataset( - {"psd": (["bin"], log_y)}, - coords={"psd_bins": (["bin"], log_x)}, - attrs={"description": "1-D Isotropic power spectrum", "units": "dB"} - ) + else: + metadata = {} - model_params = { - "beta_1": float(beta_1), - "beta_2": float(beta_2), - "c1": float(c1), - "c2": float(c2), - "scale_break": float(scale_break) - } - - return ps_dataset, model_params - - -def prob_dist(data: np.ndarray, metadata: dict): - """ - Calculate the cumulative probability distribution for rain > threshold for dB field - - Args: - data (np.ndarray): 2D field of dB-transformed rain. - metadata (dict): pysteps metadata dictionary. 
- - Returns: - tuple: - - xarray Dataset containing the cumulative probability distribution and bin edges - - fraction of field with rain > threshold (float) - """ - - rain_mask = data > metadata["zerovalue"] - - # Compute cumulative probability distribution - min_db = metadata["zerovalue"] + 0.1 - max_db = 10 * np.log10(MAX_RAIN_RATE) - bin_edges = np.linspace(min_db, max_db, N_BINS) - - # Histogram of rain values - hist, _ = np.histogram(data[rain_mask], bins=bin_edges, density=True) - - # Compute cumulative distribution - cumulative_distr = np.cumsum(hist) / np.sum(hist) - - # Create an xarray Dataset to store both cumulative distribution and bin edges - cdf_dataset = xr.Dataset( - { - "cdf": (["bin"], cumulative_distr), - }, - coords={ - # bin_edges[:-1] to match the histogram bins - "cdf_bins": (["bin"], bin_edges[:-1]), - }, - attrs={ - "description": "Cumulative probability distribution of rain rates", - "units": "dB", + return { + "metadata": metadata, + "nonzero_mean_db": self.nonzero_mean_db, + "nonzero_stdev_db": self.nonzero_stdev_db, + "rain_fraction": self.rain_fraction, + "beta_1": self.beta_1, + "beta_2": self.beta_2, + "corl_zero": self.corl_zero, + "lag_1": self.lag_1, + "lag_2": self.lag_2, } - ) - - return cdf_dataset - - -def compute_field_stats(data, geodata): - nonzero_mask = data > geodata["zerovalue"] - nonzero_mean = np.mean(data[nonzero_mask]) if np.any( - nonzero_mask) else np.nan - nonzero_stdev = np.std(data[nonzero_mask]) if np.any( - nonzero_mask) else np.nan - nonzero_frac = np.sum(nonzero_mask) / data.size - mean_rain = np.nanmean(data) - stdev_rain = np.nanstd(data) - - rain_stats = { - "nonzero_mean": float(nonzero_mean) if nonzero_mean is not None else None, - "nonzero_stdev": float(nonzero_stdev) if nonzero_stdev is not None else None, - "nonzero_fraction": float(nonzero_frac) if nonzero_frac is not None else None, - "mean": float(mean_rain) if mean_rain is not None else None, - "stdev": float(stdev_rain) if stdev_rain is not None else None, - "transform": geodata["transform"], - "zerovalue": geodata["zerovalue"], - "threshold": geodata["threshold"] - } - return rain_stats - - -def get_param_by_key( - params_df: pd.DataFrame, - valid_time: datetime.datetime, - base_time: Optional[datetime.datetime] = None, - ensemble: Optional[Union[int, str]] = None, - strict: bool = False -) -> Optional[StochasticRainParameters]: - """ - Retrieve the StochasticRainParameters object from a DataFrame index. - - Uses 'NA' as sentinel for missing base_time/ensemble. - - Args: - params_df (pd.DataFrame): Indexed by (valid_time, base_time, ensemble). - valid_time (datetime): Required valid_time. - base_time (datetime or None): Optional base_time. - ensemble (int, str, or None): Optional ensemble. - strict (bool): Raise KeyError if not found (default: False = return None) - - Returns: - StochasticRainParameters or None - """ - base_time = base_time if base_time is not None else "NA" - ensemble = ensemble if ensemble is not None else "NA" - try: - return params_df.loc[(valid_time, base_time, ensemble), "param"] - except KeyError: - if strict: - raise - return None - - -def is_stationary(phi1, phi2): - return abs(phi2) < 1 and (phi1 + phi2) < 1 and (phi2 - phi1) < 1 - - -def correlation_length(lag1: float, lag2: float, dx=10, tol=1e-4, max_lag=1000): - """ - Calculate the correlation length in minutes assuming AR(2) process - Args: - lag1 (float): Lag 1 auto-correltion - lag2 (float): Lag 2 auto-correlation - dx (int, optional): time step between lag1 & 2 in minutes. 
Defaults to 10. - tol (float, optional): _description_. Defaults to 1e-4. - max_lag (int, optional): _description_. Defaults to 1000. - - Returns: - corl (float): Correlation length in minutes - np.nan on error - """ - if lag1 is None or lag2 is None: - return np.nan - - A = np.array([[1.0, lag1], [lag1, 1.0]]) - b = np.array([lag1, lag2]) - - try: - phi = np.linalg.solve(A, b) - except np.linalg.LinAlgError: - return np.nan - - phi1, phi2 = phi - if not is_stationary(phi1, phi2): - return np.nan - - rho_vals = [1.0, lag1, lag2] - for _ in range(3, max_lag): - next_rho = phi1 * rho_vals[-1] + phi2 * rho_vals[-2] - if abs(next_rho) < tol: - break - rho_vals.append(next_rho) - corl = np.trapz(rho_vals, dx=dx) - return corl - - -def power_law_acor(config: Dict[str, Any], T_ref: float) -> np.ndarray: - """ - Compute lag-1 and lag-2 autocorrelations for each cascade level using a power-law model. - - Args: - config (dict): Configuration dictionary with 'pysteps.timestep' (in seconds) - and 'dynamic_scaling' parameters. - T_ref (float): Reference correlation length T(t, L) at the largest scale (in minutes). - - Returns: - np.ndarray: Array of shape (n_levels, 2) with [lag1, lag2] for each level. - np.ndarray: Array of corelation lengths per level - """ - dt_seconds = config["pysteps"]["timestep"] - dt_mins = dt_seconds / 60.0 - - ds_config = config.get("dynamic_scaling", {}) - scales = ds_config["central_wave_lengths"] - ht = ds_config["space_time_exponent"] - a = ds_config["lag2_constants"] - b = ds_config["lag2_exponents"] - - L = scales[0] - T_levels = [T_ref * (l / L) ** ht for l in scales] - - lags = np.empty((len(scales), 2), dtype=np.float32) - for ia, T_l in enumerate(T_levels): - pl_lag1 = np.exp(-dt_mins / T_l) - pl_lag2 = a[ia] * (pl_lag1 ** b[ia]) - lags[ia, 0] = pl_lag1 - lags[ia, 1] = pl_lag2 - - return lags, T_levels - -def blend_param(qpe_params, nwp_params, param_names, weight): - for pname in param_names: - - qval = getattr(qpe_params, pname, None) - nval = getattr(nwp_params, pname, None) - if isinstance(qval, (int, float)) and isinstance(nval, (int, float)): - setattr(nwp_params, pname, weight * qval + (1 - weight) * nval) - elif isinstance(qval, list) and isinstance(nval, list) and len(qval) == len(nval): - setattr(nwp_params, pname, [ - weight * q + (1 - weight) * n for q, n in zip(qval, nval)]) - return nwp_params - - -def blend_parameters(config, blend_base_time: datetime.datetime, nwp_param_df: pd.DataFrame, rad_param: StochasticRainParameters, - weight_fn: Callable[[float], float] = None - ) -> pd.DataFrame: - - if weight_fn is None: - def weight_fn(lag_sec): return np.exp(-(lag_sec / 10800) - ** 2) # 3h Gaussian - blended_param_names = [ - "nonzero_mean_db", - "nonzero_stdev_db", - "rain_fraction", - "beta_1", - "beta_2", - "corl_zero" - ] - blended_df = copy.deepcopy(nwp_param_df) - for vtime in blended_df.index: - lag_sec = (vtime - blend_base_time).total_seconds() - weight = weight_fn(lag_sec) - - # Select the parameter object for this vtime and blend - original = blended_df.loc[vtime, "param"] - clean_original = copy.deepcopy(original) - updated = blend_param(rad_param, clean_original, blended_param_names, weight) - - # Update the auto-correlations using the dynamic scaling parameters - updated.calc_acor(config) - blended_df.loc[vtime, "param"] = updated - - return blended_df - diff --git a/pysteps/param/stochastic_generator.py b/pysteps/param/stochastic_generator.py index 15bd178d4..bc610a0d7 100644 --- a/pysteps/param/stochastic_generator.py +++ 
b/pysteps/param/stochastic_generator.py @@ -2,24 +2,33 @@ from typing import Optional import numpy as np from scipy import interpolate, stats -from models import StochasticRainParameters +from steps_params import StepsParameters -def gen_stoch_field(steps_params: StochasticRainParameters, nx: int, ny: int): + +def gen_stoch_field( + steps_params: StepsParameters, + nx: int, + ny: int, + pixel_size: float, + scale_break: float, + threshold: float, +): """ - Generate a rain field with normal distribution and a power law power spectrum + Generate a rain field with normal distribution and a power law power spectrum Args: - steps_params (StochasticRainParameters): The dataclass with all the steps parameters + steps_params (StepsParameters): The dataclass with the steps parameters nx (int): x dimension of the output field ny (int): y dimension of the output field + kmperpixel (float): pixel size + scale_break (float): scale break in km + threshold (float): rain threshold in db Returns: np.ndarray: Output field with shape (ny,nx) """ - beta_1 = steps_params.beta_1 - beta_2 = steps_params.beta_2 - pixel_size = steps_params.kmperpixel - scale_break = pixel_size * steps_params.scale_break + beta_1 = steps_params.beta_1 + beta_2 = steps_params.beta_2 # generate uniform random numbers in the range 0,1 y = np.random.uniform(low=0, high=1, size=(ny, nx)) @@ -30,11 +39,12 @@ def gen_stoch_field(steps_params: StochasticRainParameters, nx: int, ny: int): out_fft = fft * filter out_field = np.fft.ifft2(out_fft).real - nbins = 250 + nbins = 500 + eps = 0.001 + res = stats.cumfreq(out_field, numbins=nbins) - bins = [res.lowerlimit + ia * - res.binsize for ia in range(1+res.cumcount.size)] - count = res.cumcount / res.cumcount[nbins-1] + bins = [res.lowerlimit + ia * res.binsize for ia in range(1 + res.cumcount.size)] + count = res.cumcount / res.cumcount[nbins - 1] # find the threshold value for this non-rain probability rain_bin = 0 @@ -49,48 +59,52 @@ def gen_stoch_field(steps_params: StochasticRainParameters, nx: int, ny: int): norm_data = out_field - rain_threshold # Now we need to transform the "raining" samples to have the desired distribution - rain_mask = norm_data > steps_params.threshold + rain_mask = norm_data > threshold rain_obs = norm_data[rain_mask] rain_res = stats.cumfreq(rain_obs, numbins=nbins) - rain_bins = [rain_res.lowerlimit + ia * - rain_res.binsize for ia in range(1+rain_res.cumcount.size)] - rain_cdf = rain_res.cumcount / rain_res.cumcount[nbins-1] + rain_bins = [ + rain_res.lowerlimit + ia * rain_res.binsize + for ia in range(1 + rain_res.cumcount.size) + ] + rain_cdf = rain_res.cumcount / rain_res.cumcount[nbins - 1] # rain_bins are the bin edges; use bin centers for interpolation bin_centers = 0.5 * (np.array(rain_bins[:-1]) + np.array(rain_bins[1:])) # Step 1: Build LUT: map empirical CDF → target normal quantiles # Make sure rain_cdf values are in (0,1) to avoid issues with extreme tails - eps = 1e-6 rain_cdf_clipped = np.clip(rain_cdf, eps, 1 - eps) - # Map rain_cdf quantiles to corresponding values in the target normal distribution - target_mu = steps_params.nonzero_mean_db - target_sigma = steps_params.nonzero_stdev_db * 0.80 - normal_values = stats.norm.ppf( - rain_cdf_clipped, loc=target_mu, scale=target_sigma) + # Map rain_cdf quantiles to corresponding values in the target normal distribution + target_mu = steps_params.nonzero_mean_db + target_sigma = steps_params.nonzero_stdev_db + normal_values = stats.norm.ppf(rain_cdf_clipped, loc=target_mu, scale=target_sigma) # 
Create interpolation function from observed rain values to target normal values cdf_transform = interpolate.interp1d( - bin_centers, normal_values, - kind="linear", bounds_error=False, - fill_value=(normal_values[0], normal_values[-1]) + bin_centers, + normal_values, + kind="linear", + bounds_error=False, + fill_value=(normal_values[0], normal_values[-1]), # type: ignore ) - # Transform raining pixels + # Transform pdf of the raining pixels norm_data[rain_mask] = cdf_transform(norm_data[rain_mask]) + return norm_data -def normalize_db_field(data, params): +def normalize_db_field(data, params, threshold, zerovalue): if params.rain_fraction < 0.025: - return np.full_like(data, params.zerovalue) - - nbins = 250 + return np.full_like(data, zerovalue) + + nbins = 500 + eps = 0.0001 + res = stats.cumfreq(data, numbins=nbins) - bins = [res.lowerlimit + ia * - res.binsize for ia in range(1+res.cumcount.size)] - count = res.cumcount / res.cumcount[nbins-1] + bins = [res.lowerlimit + ia * res.binsize for ia in range(1 + res.cumcount.size)] + count = res.cumcount / res.cumcount[nbins - 1] # find the threshold value for this non-rain probability rain_bin = 0 @@ -99,57 +113,73 @@ def normalize_db_field(data, params): rain_bin = ia else: break - rain_threshold = bins[rain_bin+1] + rain_threshold = bins[rain_bin + 1] # Shift the data to have the correct probability of rain - norm_data = data + (params.threshold - rain_threshold) + norm_data = data + (threshold - rain_threshold) # Now we need to transform the raining samples to have the desired distribution # Get the sample distribution - rain_mask = norm_data > params.threshold + rain_mask = norm_data > threshold rain_obs = norm_data[rain_mask] rain_res = stats.cumfreq(rain_obs, numbins=nbins) - rain_bins = [rain_res.lowerlimit + ia * - rain_res.binsize for ia in range(1+rain_res.cumcount.size)] - rain_cdf = rain_res.cumcount / rain_res.cumcount[nbins-1] + + rain_bins = [ + rain_res.lowerlimit + ia * rain_res.binsize + for ia in range(1 + rain_res.cumcount.size) + ] + rain_cdf = rain_res.cumcount / rain_res.cumcount[nbins - 1] # rain_bins are the bin edges; use bin centers for interpolation bin_centers = 0.5 * (np.array(rain_bins[:-1]) + np.array(rain_bins[1:])) # Step 1: Build LUT: map empirical CDF → target normal quantiles # Make sure rain_cdf values are in (0,1) to avoid issues with extreme tails - eps = 5e-3 rain_cdf_clipped = np.clip(rain_cdf, eps, 1 - eps) # Map rain_cdf quantiles to corresponding values in the target normal distribution - # We need to reduce the bias in the output fields - bias_adj = 0.85 + # We need to reduce the bias in the output fields target_mu = params.nonzero_mean_db - target_sigma = params.nonzero_stdev_db * bias_adj - normal_values = stats.norm.ppf( - rain_cdf_clipped, loc=target_mu, scale=target_sigma) + target_sigma = params.nonzero_stdev_db + normal_values = stats.norm.ppf(rain_cdf_clipped, loc=target_mu, scale=target_sigma) # Create interpolation function from observed rain values to target normal values + fill_value = (normal_values[0], normal_values[-1]) cdf_transform = interpolate.interp1d( - bin_centers, normal_values, - kind="linear", bounds_error=False, - fill_value=(normal_values[0], normal_values[-1]) + bin_centers, + normal_values, + kind="linear", + bounds_error=False, + fill_value=fill_value, # type: ignore ) # Transform raining pixels norm_data[rain_mask] = cdf_transform(norm_data[rain_mask]) - return norm_data -def pl_filter(beta_1: float, nx: int, ny: int, pixel_size: float, beta_2: Optional[float] = 
None, scale_break: Optional[float] = None, - ): + # Check if we have nans and return zerovalue if yes + has_nan = np.isnan(norm_data).any() + if has_nan: + return np.full_like(data, zerovalue) + else: + return norm_data + + +def pl_filter( + beta_1: float, + nx: int, + ny: int, + pixel_size: float, + beta_2: Optional[float] = None, + scale_break: Optional[float] = None, +): """ Generate a 2D low-pass power-law filter for FFT filtering. Parameters: - beta_1 (float): Power law exponent for frequencies < f1 (low frequencies) + beta_1 (float): Power law exponent for frequencies < f1 (low frequencies) nx (int): Number of columns (width) in the 2D field ny (int): Number of rows (height) in the 2D field - pixel_size (float): Pixel size in km + pixel_size (float): Pixel size in km beta_2 (float): Power law exponent for frequencies > f1 (high frequencies) Optional scale_break (float): Break scale in km Optional @@ -168,20 +198,20 @@ def pl_filter(beta_1: float, nx: int, ny: int, pixel_size: float, beta_2: Option filter_r = np.ones_like(freq_r) # Initialize with ones f_zero = freq_x[1] - if beta_2 is not None: + if beta_2 is not None and scale_break is not None: b1 = beta_1 / 2.0 - b2 = (beta_2-0.3) / 2.0 + b2 = beta_2 / 2.0 f1 = 1 / scale_break # Convert scale break to frequency domain - weight = (f1/f_zero) ** b1 + weight = (f1 / f_zero) ** b1 # Apply the power-law function for a **low-pass filter** # Handle division by zero at freq = 0 - with np.errstate(divide='ignore', invalid='ignore'): + with np.errstate(divide="ignore", invalid="ignore"): mask_low = freq_r < f1 # Frequencies lower than the break mask_high = ~mask_low # Frequencies higher than or equal to the break - filter_r[mask_low] = (freq_r[mask_low]/f_zero) ** b1 + filter_r[mask_low] = (freq_r[mask_low] / f_zero) ** b1 filter_r[mask_high] = weight * (freq_r[mask_high] / f1) ** b2 # Ensure DC component (zero frequency) is handled properly @@ -189,7 +219,7 @@ def pl_filter(beta_1: float, nx: int, ny: int, pixel_size: float, beta_2: Option else: b1 = beta_1 / 2.0 mask = freq_r > 0 - filter_r[mask] = (freq_r[mask]/f_zero) ** b1 + filter_r[mask] = (freq_r[mask] / f_zero) ** b1 filter_r[freq_r == 0] = 1 # Preserve the mean component return filter_r diff --git a/pysteps/param/transformer.py b/pysteps/param/transformer.py new file mode 100644 index 000000000..7e5325c02 --- /dev/null +++ b/pysteps/param/transformer.py @@ -0,0 +1,187 @@ +import numpy as np +import scipy.stats as scipy_stats +from scipy.interpolate import interp1d +from typing import Optional + + +class BaseTransformer: + def __init__(self, threshold: float = 0.5, zerovalue: Optional[float] = None): + self.threshold = threshold + self.zerovalue = zerovalue + self.metadata = {} + + def transform(self, R: np.ndarray) -> np.ndarray: + raise NotImplementedError + + def inverse_transform(self, R: np.ndarray) -> np.ndarray: + raise NotImplementedError + + def get_metadata(self) -> dict: + return self.metadata.copy() + + +class DBTransformer(BaseTransformer): + """ + DBTransformer applies a thresholded dB transform to rain rate fields. + + Parameters: + threshold (float): Rain rate threshold (in mm/h). Values below this are set to `zerovalue` in dB. + zerovalue (Optional[float]): Value in dB space to assign below-threshold pixels. 
If None, defaults to log10(threshold) - 0.1 + """ + + def __init__(self, threshold: float = 0.5, zerovalue: Optional[float] = None): + super().__init__(threshold, zerovalue) + threshold_db = 10.0 * np.log10(self.threshold) + + if self.zerovalue is None: + self.zerovalue = threshold_db - 0.1 + + self.metadata = { + "transform": "dB", + "threshold": self.threshold, # stored in mm/h + "zerovalue": self.zerovalue, # stored in dB + } + + def transform(self, R: np.ndarray) -> np.ndarray: + R = R.copy() + mask = R < self.threshold + R[~mask] = 10.0 * np.log10(R[~mask]) + R[mask] = self.zerovalue + return R + + def inverse_transform(self, R: np.ndarray) -> np.ndarray: + R = R.copy() + R = 10.0 ** (R / 10.0) + R[R < self.threshold] = 0 + return R + + +class BoxCoxTransformer(BaseTransformer): + def __init__(self, Lambda: float = 0.0, **kwargs): + super().__init__(**kwargs) + self.Lambda = Lambda + + def transform(self, R: np.ndarray) -> np.ndarray: + R = R.copy() + mask = R < self.threshold + + if self.Lambda == 0.0: + R[~mask] = np.log(R[~mask]) + tval = np.log(self.threshold) + else: + R[~mask] = (R[~mask] ** self.Lambda - 1) / self.Lambda + tval = (self.threshold**self.Lambda - 1) / self.Lambda + + if self.zerovalue is None: + self.zerovalue = tval - 1 + + R[mask] = self.zerovalue + + self.metadata = { + "transform": "BoxCox", + "lambda": self.Lambda, + "threshold": tval, + "zerovalue": self.zerovalue, + } + return R + + def inverse_transform(self, R: np.ndarray) -> np.ndarray: + R = R.copy() + if self.Lambda == 0.0: + R = np.exp(R) + else: + R = np.exp(np.log(self.Lambda * R + 1) / self.Lambda) + + threshold_inv = ( + np.exp(np.log(self.Lambda * self.metadata["threshold"] + 1) / self.Lambda) + if self.Lambda != 0.0 + else np.exp(self.metadata["threshold"]) + ) + + R[R < threshold_inv] = self.metadata["zerovalue"] + self.metadata["transform"] = None + return R + + +class NQTransformer(BaseTransformer): + def __init__(self, a: float = 0.0, **kwargs): + super().__init__(**kwargs) + self.a = a + self._inverse_interp = None + + def transform(self, R: np.ndarray) -> np.ndarray: + R = R.copy() + shape = R.shape + R = R.ravel() + mask = ~np.isnan(R) + R_ = R[mask] + + n = R_.size + Rpp = (np.arange(n) + 1 - self.a) / (n + 1 - 2 * self.a) + Rqn = scipy_stats.norm.ppf(Rpp) + R_sorted = R_[np.argsort(R_)] + R_trans = np.interp(R_, R_sorted, Rqn) + + self.zerovalue = np.min(R_) + R_trans[R_ == self.zerovalue] = 0 + + self._inverse_interp = interp1d( + Rqn, + R_sorted, + bounds_error=False, + fill_value=(float(R_sorted.min()), float(R_sorted.max())), # type: ignore + ) + + R[mask] = R_trans + R = R.reshape(shape) + + self.metadata = { + "transform": "NQT", + "threshold": R_trans[R_trans > 0].min(), + "zerovalue": 0, + } + return R + + def inverse_transform(self, R: np.ndarray) -> np.ndarray: + if self._inverse_interp is None: + raise RuntimeError("Must call transform() before inverse_transform()") + + R = R.copy() + shape = R.shape + R = R.ravel() + mask = ~np.isnan(R) + R[mask] = self._inverse_interp(R[mask]) + R = R.reshape(shape) + + self.metadata["transform"] = None + return R + + +class SqrtTransformer(BaseTransformer): + def transform(self, R: np.ndarray) -> np.ndarray: + R = np.sqrt(R) + self.metadata = { + "transform": "sqrt", + "threshold": np.sqrt(self.threshold), + "zerovalue": np.sqrt(self.zerovalue) if self.zerovalue else 0.0, + } + return R + + def inverse_transform(self, R: np.ndarray) -> np.ndarray: + R = R**2 + self.metadata["transform"] = None + return R + + +def get_transformer(name: str, 
**kwargs) -> BaseTransformer: + name = name.lower() + if name == "boxcox": + return BoxCoxTransformer(**kwargs) + elif name == "db": + return DBTransformer(**kwargs) + elif name == "nqt": + return NQTransformer(**kwargs) + elif name == "sqrt": + return SqrtTransformer(**kwargs) + else: + raise ValueError(f"Unknown transformer type: {name}") From 3355a3edea25f372b012b9c37c8646a4385ed48e Mon Sep 17 00:00:00 2001 From: Alan Seed Date: Fri, 29 Aug 2025 14:00:27 +1000 Subject: [PATCH 05/12] fixed module names --- pysteps/param/nc_utils.py | 2 +- pysteps/param/shared_utils.py | 12 ++++++------ pysteps/param/stochastic_generator.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pysteps/param/nc_utils.py b/pysteps/param/nc_utils.py index 17a304172..ff0604469 100644 --- a/pysteps/param/nc_utils.py +++ b/pysteps/param/nc_utils.py @@ -66,7 +66,7 @@ def read_qpe_netcdf(file_path): """ try: - ds = xr.open_dataset(file_path, decode_times=True) + ds = xr.open_dataset(file_path, decode_cf=True, mask_and_scale=True) ds.load() # Make the times timezone-aware UTC diff --git a/pysteps/param/shared_utils.py b/pysteps/param/shared_utils.py index f513c1754..4da386e90 100644 --- a/pysteps/param/shared_utils.py +++ b/pysteps/param/shared_utils.py @@ -15,12 +15,12 @@ ) from pysteps import extrapolation -from utils.transformer import DBTransformer -from steps_params import StepsParameters -from stochastic_generator import gen_stoch_field, normalize_db_field -from rainfield_stats import correlation_length -from rainfield_stats import power_spectrum_1D -from cascade_utils import lagr_auto_cor +from pysteps.utils.transformer import DBTransformer +from .steps_params import StepsParameters +from .stochastic_generator import gen_stoch_field, normalize_db_field +from .rainfield_stats import correlation_length +from .rainfield_stats import power_spectrum_1D +from .cascade_utils import lagr_auto_cor def update_field( diff --git a/pysteps/param/stochastic_generator.py b/pysteps/param/stochastic_generator.py index bc610a0d7..22cc68117 100644 --- a/pysteps/param/stochastic_generator.py +++ b/pysteps/param/stochastic_generator.py @@ -2,7 +2,7 @@ from typing import Optional import numpy as np from scipy import interpolate, stats -from steps_params import StepsParameters +from .steps_params import StepsParameters def gen_stoch_field( From 25ad0e739c3bceee5d5ef61a274b47655fe32525 Mon Sep 17 00:00:00 2001 From: Alan Seed Date: Tue, 9 Sep 2025 23:06:10 +1000 Subject: [PATCH 06/12] update modules --- pysteps/param/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pysteps/param/__init__.py b/pysteps/param/__init__.py index a48e10992..5eb5d7412 100644 --- a/pysteps/param/__init__.py +++ b/pysteps/param/__init__.py @@ -15,9 +15,8 @@ blend_parameters, zero_state, is_zero_state, - calc_corls, + calc_auto_cors, fit_auto_cors, calculate_parameters, ) -from .nc_utils import generate_geo_dict, generate_geo_dict_xy, read_qpe_netcdf from .transformer import DBTransformer From d13c079fb0bdf911cce682e476b2859f71a07135 Mon Sep 17 00:00:00 2001 From: Alan Seed Date: Tue, 9 Sep 2025 23:07:04 +1000 Subject: [PATCH 07/12] nc_utils in mongo module --- pysteps/mongo/nc_utils.py | 444 ++++++++++++++++++++++---------------- 1 file changed, 258 insertions(+), 186 deletions(-) diff --git a/pysteps/mongo/nc_utils.py b/pysteps/mongo/nc_utils.py index a8709ccb5..cc5690cf7 100644 --- a/pysteps/mongo/nc_utils.py +++ b/pysteps/mongo/nc_utils.py @@ -1,93 +1,126 @@ """ - Refactored IO utilities for pysteps. 
+Refactored IO utilities for pysteps. """ + import numpy as np from pyproj import CRS import netCDF4 -from datetime import datetime, timezone +import datetime from typing import Optional import io +from pathlib import Path + def replace_extension(filename: str, new_ext: str) -> str: return f"{filename.rsplit('.', 1)[0]}{new_ext}" + def convert_timestamps_to_datetimes(timestamps): """Convert POSIX timestamps to datetime objects.""" - return [datetime.fromtimestamp(ts, tz=timezone.utc) for ts in timestamps] - - -def write_netcdf(rain: np.ndarray, geo_data: dict, time: int): + return [ + datetime.datetime.fromtimestamp(ts, tz=datetime.timezone.utc) + for ts in timestamps + ] + + +def write_netcdf_file( + file_path: Path, + rain: np.ndarray, + geo_data: dict, + valid_times: list[datetime.datetime], + ensembles: list[int] | None, +) -> None: """ - Write rain data as a NetCDF4 memory buffer. - - :param buffer: A BytesIO buffer to store the NetCDF data. - :param rain: Rainfall data as a NumPy array. - :param geo_data: Dictionary containing geo-referencing data with keys: - 'x', 'y', 'projection', and other metadata. - :param time: POSIX timestamp representing the time dimension. - :return: The BytesIO buffer containing the NetCDF data. + Write a set of rainfall grids to a CF-compliant NetCDF file using i2 data and scale_factor. + + Args: + file_path (Path): Full path to the output file. + rain (np.ndarray): Rainfall array. Shape is [ensemble, time, y, x] if ensembles is provided, + otherwise [time, y, x], with units in mm/h as float. + geo_data (dict): Geospatial metadata (must include 'x', 'y', and optionally 'projection'). + valid_times (list[datetime.datetime]): List of timezone-aware valid times. + ensembles (list[int] | None): Optional list of ensemble member IDs. 
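+
+    Example (a minimal sketch; the grid spacing, projection code, file name
+    and rain values below are illustrative assumptions only):
+
+        import datetime
+        import numpy as np
+        from pathlib import Path
+
+        x = np.arange(10) * 1000.0   # 10 columns at 1 km spacing, in metres
+        y = np.arange(8) * 1000.0    # 8 rows
+        geo = {"x": x, "y": y, "projection": "EPSG:2193"}
+        t0 = datetime.datetime(2025, 1, 1, tzinfo=datetime.timezone.utc)
+        rain = np.random.gamma(2.0, 1.0, size=(1, 8, 10))  # [time, y, x], mm/h
+        write_netcdf_file(Path("example.nc"), rain, geo, [t0], ensembles=None)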
""" - x = geo_data['x'] - y = geo_data['y'] - # Default to WGS84 if not provided - projection = geo_data.get('projection', 'EPSG:4326') - - # Create an in-memory NetCDF dataset - ds = netCDF4.Dataset('inmemory.nc', mode='w', memory=1024) - - # Define dimensions - y_dim = ds.createDimension("y", len(y)) - x_dim = ds.createDimension("x", len(x)) - t_dim = ds.createDimension("time", 1) - - # Define coordinate variables - y_var = ds.createVariable("y", "f4", ("y",)) - x_var = ds.createVariable("x", "f4", ("x",)) - t_var = ds.createVariable("time", "i8", ("time",)) - - # Define rain variable - rain_var = ds.createVariable( - "rainfall", "i2", ("time", "y", "x"), zlib=True - ) # int16 with a fill value - rain_var.scale_factor = 0.1 - rain_var.add_offset = 0.0 - rain_var.units = "mm/h" - rain_var.long_name = "Rainfall rate" - rain_var.grid_mapping = "projection" - - # Assign coordinate values - y_var[:] = y - y_var.standard_name = "projection_y_coordinate" - y_var.units = "m" - - x_var[:] = x - x_var.standard_name = "projection_x_coordinate" - x_var.units = "m" - - t_var[:] = [time] - t_var.standard_name = "time" - t_var.units = "seconds since 1970-01-01T00:00:00Z" + # Convert datetime to seconds since epoch + time_stamps = [vt.timestamp() for vt in valid_times] + + x = geo_data["x"] + y = geo_data["y"] + projection = geo_data.get("projection", "EPSG:4326") + rain_fill_value = -1 + + with netCDF4.Dataset(file_path, mode="w", format="NETCDF4") as ds: + # Define dimensions + ds.createDimension("y", len(y)) + ds.createDimension("x", len(x)) + ds.createDimension("time", len(valid_times)) + if ensembles is not None: + ds.createDimension("ensemble", len(ensembles)) + + # Define coordinate variables + x_var = ds.createVariable("x", "f4", ("x",)) + x_var[:] = x + x_var.standard_name = "projection_x_coordinate" + x_var.units = "m" + + y_var = ds.createVariable("y", "f4", ("y",)) + y_var[:] = y + y_var.standard_name = "projection_y_coordinate" + y_var.units = "m" + + t_var = ds.createVariable("time", "f8", ("time",)) + t_var[:] = time_stamps + t_var.standard_name = "time" + t_var.units = "seconds since 1970-01-01T00:00:00Z" + t_var.calendar = "standard" + + if ensembles is not None: + e_var = ds.createVariable("ensemble", "i4", ("ensemble",)) + e_var[:] = ensembles + e_var.standard_name = "ensemble" + e_var.units = "1" + + # Define the rainfall variable with proper fill_value + rain_dims = ( + ("time", "y", "x") if ensembles is None else ("ensemble", "time", "y", "x") + ) + rain_var = ds.createVariable( + "rainfall", + "i2", + rain_dims, + zlib=True, + complevel=5, + fill_value=rain_fill_value, + ) + + # Scale and store rainfall + rain_scaled = np.where( + np.isnan(rain), rain_fill_value, np.round(rain * 10).astype(np.int16) + ) + rain_var[...] 
= rain_scaled + + # Metadata + rain_var.scale_factor = 0.1 + rain_var.add_offset = 0.0 + rain_var.units = "mm/h" + rain_var.long_name = "Rainfall rate" + rain_var.grid_mapping = "projection" + rain_var.coordinates = " ".join(rain_dims) + + # CRS + crs = CRS.from_user_input(projection) + cf_grid_mapping = crs.to_cf() + spatial_ref = ds.createVariable("projection", "i4") + for key, value in cf_grid_mapping.items(): + setattr(spatial_ref, key, value) + + # Global attributes + ds.Conventions = "CF-1.8" + ds.title = "" + ds.institution = "" + ds.references = "" + ds.comment = "" - # Handle NaNs in rain data and assign to variable - rain[np.isnan(rain)] = -1 - rain_var[0,:, :] = rain - - # Define spatial reference (CRS) - crs = CRS.from_user_input(projection) - cf_grid_mapping = crs.to_cf() - - # Create spatial reference variable - spatial_ref = ds.createVariable("projection", "i4") - for key, value in cf_grid_mapping.items(): - setattr(spatial_ref, key, value) - - # Add global attributes - ds.Conventions = "CF-1.7" - ds.title = "Rainfall data" - ds.institution = "Weather Radar New Zealand Ltd" - ds.references = "" - ds.comment = "" - return ds.close() import io import tempfile @@ -96,21 +129,28 @@ def write_netcdf(rain: np.ndarray, geo_data: dict, time: int): import numpy as np from pyproj import CRS -def write_netcdf_io(rain: np.ndarray, geo_data: dict, time: int) -> io.BytesIO: + +def make_netcdf_buffer(rain: np.ndarray, geo_data: dict, time: int) -> io.BytesIO: """ - Write a NetCDF file to a temporary file, read it into memory, and return a BytesIO buffer. + Make the BytesIO netcdf object that is needed for writing to GridFS database + Args: + rain (np.ndarray): array of rain rates in mm/h as float + geo_data (dict): spatial metadata + time (int): seconds since 1970-01-01T00:00:00Z + + Returns: + io.BytesIO: _description_ """ - - x = geo_data['x'] - y = geo_data['y'] - projection = geo_data.get('projection', 'EPSG:4326') + x = geo_data["x"] + y = geo_data["y"] + projection = geo_data.get("projection", "EPSG:4326") # Use NamedTemporaryFile to create a temp NetCDF file with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: tmp_path = tmp.name # Create NetCDF file on disk - ds = netCDF4.Dataset(tmp_path, mode='w', format='NETCDF4') + ds = netCDF4.Dataset(tmp_path, mode="w", format="NETCDF4") # Define dimensions ds.createDimension("y", len(y)) @@ -122,10 +162,10 @@ def write_netcdf_io(rain: np.ndarray, geo_data: dict, time: int) -> io.BytesIO: x_var = ds.createVariable("x", "f4", ("x",)) t_var = ds.createVariable("time", "i8", ("time",)) - # Rainfall variable + # Rainfall variable, + # Expects a float input array and the packing to i2 is done by the netCDF4 library rain_var = ds.createVariable( - "rainfall", "i2", ("time", "y", "x"), - zlib=True, complevel=5, fill_value=-1 + "rainfall", "i2", ("time", "y", "x"), zlib=True, complevel=5, fill_value=-1 ) rain_var.scale_factor = 0.1 rain_var.add_offset = 0.0 @@ -154,9 +194,9 @@ def write_netcdf_io(rain: np.ndarray, geo_data: dict, time: int) -> io.BytesIO: setattr(spatial_ref, key, value) # Global attributes - ds.Conventions = "CF-1.7" + ds.Conventions = "CF-1.8" ds.title = "Rainfall data" - ds.institution = "Weather Radar New Zealand Ltd" + ds.institution = "" ds.references = "" ds.comment = "" @@ -170,25 +210,48 @@ def write_netcdf_io(rain: np.ndarray, geo_data: dict, time: int) -> io.BytesIO: return io.BytesIO(nc_bytes) -def generate_geo_data(x, y, projection='EPSG:2193'): - """Generate geo-referencing data.""" - return { - 
"projection": projection, - "x": x, - "y": y, - "x1": np.round(x[0],decimals=0), - "x2": np.round(x[-1],decimals=0), - "y1": np.round(y[0],decimals=0), - "y2": np.round(y[-1],decimals=0), - "xpixelsize": np.round(x[1] - x[0],decimals=0), - "ypixelsize": np.round(y[1] - y[0],decimals=0), - "cartesian_unit": 'm', - "yorigin": 'lower', - "unit": 'mm/h', - "transform": None, - "threshold": 0.1, - "zerovalue": 0 - } +def generate_geo_dict(domain: dict) -> dict: + """ + Generate the pysteps geo-spatial metadata from a domain dictionary. + + Args: + domain (dict): pysteps_param domain dictionary + + Returns: + dict: pysteps geo-data dictionary, or {} if required keys are missing + """ + required_keys = {"n_cols", "n_rows", "p_size", "start_x", "start_y"} + missing = required_keys - domain.keys() + if missing: + # Missing keys, return empty dict + return {} + + ncols = domain.get("n_cols") + nrows = domain.get("n_rows") + psize = domain.get("p_size") + start_x = domain.get("start_x") + start_y = domain.get("start_y") + + x = [start_x + i * psize for i in range(ncols)] # type: ignore + y = [start_y + i * psize for i in range(nrows)] # type: ignore + + out_geo = {} + out_geo["x"] = x + out_geo["y"] = y + out_geo["xpixelsize"] = psize + out_geo["ypixelsize"] = psize + out_geo["x1"] = start_x + out_geo["y1"] = start_y + out_geo["x2"] = start_x + (ncols - 1) * psize # type: ignore + out_geo["y2"] = start_y + (nrows - 1) * psize # type: ignore + out_geo["projection"] = domain["projection"]["epsg"] + out_geo["cartesian_unit"] = ("m",) + out_geo["yorigin"] = ("lower",) + out_geo["unit"] = "mm/h" + out_geo["threshold"] = 0 + out_geo["transform"] = None + + return out_geo def read_nc(buffer: bytes): @@ -202,11 +265,19 @@ def read_nc(buffer: bytes): byte_stream = io.BytesIO(buffer) # Open the NetCDF dataset - with netCDF4.Dataset('inmemory', mode='r', memory=byte_stream.getvalue()) as ds: + with netCDF4.Dataset("inmemory", mode="r", memory=byte_stream.getvalue()) as ds: + # Extract geo-referencing data x = ds.variables["x"][:] y = ds.variables["y"][:] - geo_data = generate_geo_data(x, y) + + domain = {} + domain["ncols"] = len(x) + domain["nrows"] = len(y) + domain["psize"] = abs(x[1] - x[0]) + domain["start_x"] = x[0] + domain["start_y"] = y[0] + geo_data = generate_geo_dict(domain) # Convert timestamps to datetime valid_times = convert_timestamps_to_datetimes(ds.variables["time"][:]) @@ -214,10 +285,9 @@ def read_nc(buffer: bytes): # Extract rain rates rain_rate = ds.variables["rainfall"][:] - # Replace invalid data with NaN and squeeze dimensions + # Replace invalid data with NaN and squeeze dimensions of np.ndarray rain_rate = np.squeeze(rain_rate) rain_rate[rain_rate < 0] = np.nan - valid_times = np.squeeze(valid_times) return geo_data, valid_times, rain_rate @@ -228,91 +298,93 @@ def validate_keys(keys, mandatory_keys): if missing_keys: raise KeyError(f"Missing mandatory keys: {', '.join(missing_keys)}") -def make_nc_name_dt(out_file_name, name, out_product, valid_time, base_time, iens): - vtime = valid_time - if vtime.tzinfo is None: - vtime = vtime.replace(tzinfo=timezone.utc) - vtime_stamp = vtime.timestamp() +def make_nc_name( + domain: str, + prod: str, + valid_time: datetime.datetime, + base_time: Optional[datetime.datetime] = None, + ens: Optional[int] = None, + name_template: Optional[str] = None, +) -> str: + """ + Generate a unique name for a single rain field using a formatting template. 
- if base_time is not None: - btime = base_time - if btime.tzinfo is None: - btime = btime.replace(tzinfo=timezone.utc) - btime_stamp = btime.timestamp() - else: - btime_stamp = None - - fx_file_name = make_nc_name( - out_file_name, name, out_product, vtime_stamp, btime_stamp, iens) - return fx_file_name + Default templates: + Forecast products: "$D_$P_$V{%Y%m%dT%H%M%S}_$B{%Y%m%dT%H%M%S}_$E.nc" + QPE products: "$D_$P_$V{%Y%m%dT%H%M%S}.nc" + Where: + $D = Domain name + $P = Product name + $V = Valid time (with strftime format) + $B = Base time (with strftime format) + $E = Ensemble number (zero-padded 2-digit) -def make_nc_name(name_template: str, name: str, prod: str, valid_time: int, - base_time: Optional[int] = None, ens: Optional[int] = None) -> str: + Returns: + str: Unique NetCDF file name. """ - Generate a file name using a template. - - :param name_template: Template for the file name - :param name: Name of the domain - Mandatory - :param prod: Name of the product - Mandatory - :param valid_time: Valid time of the field - Mandatory - :param run_time: NWP run time - Optional - :param ens: Ensemble member - Optional - :return: String with the file name - """ - result = name_template - # Set up the valid time - vtime_info = datetime.fromtimestamp(valid_time, tz=timezone.utc) + if not isinstance(valid_time, datetime.datetime): + raise TypeError(f"valid_time must be datetime, got {type(valid_time)}") - # Set up the NWP base time if available - btime_info = datetime.fromtimestamp( - base_time, tz=timezone.utc) if base_time is not None else None + if base_time is not None and not isinstance(base_time, datetime.datetime): + raise TypeError(f"base_time must be datetime or None, got {type(base_time)}") - has_flag = True - while has_flag: - # Search for a flag + # Default template logic + if name_template is None: + name_template = "$D_$P_$V{%Y-%m-%dT%H:%M:%S}" + if base_time is not None: + name_template += "_$B{%Y-%m-%dT%H:%M:%S}" + if ens is not None: + name_template += "_$E" + name_template += ".nc" + + result = name_template + + # Ensure timezone-aware times + if valid_time.tzinfo is None: + valid_time = valid_time.replace(tzinfo=datetime.timezone.utc) + if base_time is not None and base_time.tzinfo is None: + base_time = base_time.replace(tzinfo=datetime.timezone.utc) + + # Replace flags + while "$" in result: flag_posn = result.find("$") if flag_posn == -1: - has_flag = False - else: - # Get the field type - f_type = result[flag_posn + 1] - - try: - # Add the valid and base times - if f_type in ['V', 'B']: - # Get the required format string - field_start = result.find("{", flag_posn + 1) - field_end = result.find("}", flag_posn + 1) - if field_start == -1 or field_end == -1: - raise ValueError(f"Invalid time format for flag '${ - f_type}' in template.") - - time_format = result[field_start + 1:field_end] - if f_type == 'V': - date_str = vtime_info.strftime(time_format) - elif f_type == 'B' and btime_info: - date_str = btime_info.strftime(time_format) - else: - date_str = "" - - # Replace the format field with the formatted time - result = result[:flag_posn] + \ - date_str + result[field_end + 1:] - elif f_type == 'P': - result = result[:flag_posn] + prod + result[flag_posn + 2:] - elif f_type == 'N': - result = result[:flag_posn] + name + result[flag_posn + 2:] - elif f_type == 'E' and ens is not None: - result = result[:flag_posn] + \ - f"{ens:02d}" + result[flag_posn + 2:] + break + f_type = result[flag_posn + 1] + + try: + if f_type in ["V", "B"]: + field_start = result.find("{", 
flag_posn + 1) + field_end = result.find("}", flag_posn + 1) + if field_start == -1 or field_end == -1: + raise ValueError( + f"Missing braces for format of '${f_type}' in template." + ) + + fmt = result[field_start + 1 : field_end] + if f_type == "V": + time_str = valid_time.strftime(fmt) + elif f_type == "B" and base_time is not None: + time_str = base_time.strftime(fmt) else: - raise ValueError(f"Unknown or unsupported flag '${ - f_type}' in template.") - except Exception as e: - raise ValueError(f"Error processing flag '${ - f_type}': {str(e)}") - - return result \ No newline at end of file + time_str = "" + + result = result[:flag_posn] + time_str + result[field_end + 1 :] + + elif f_type == "D": + result = result[:flag_posn] + domain + result[flag_posn + 2 :] + elif f_type == "P": + result = result[:flag_posn] + prod + result[flag_posn + 2 :] + elif f_type == "E" and ens is not None: + result = result[:flag_posn] + f"{ens:02d}" + result[flag_posn + 2 :] + else: + raise ValueError( + f"Unknown or unsupported flag '${f_type}' in template." + ) + except Exception as e: + raise ValueError(f"Error processing flag '${f_type}': {e}") + + return result From 3a055913c1afd6d059045428a10e51306b23b92d Mon Sep 17 00:00:00 2001 From: Alan Seed Date: Tue, 9 Sep 2025 23:07:23 +1000 Subject: [PATCH 08/12] remove nc_utils from param module --- pysteps/param/nc_utils.py | 107 -------------------------------------- 1 file changed, 107 deletions(-) delete mode 100644 pysteps/param/nc_utils.py diff --git a/pysteps/param/nc_utils.py b/pysteps/param/nc_utils.py deleted file mode 100644 index ff0604469..000000000 --- a/pysteps/param/nc_utils.py +++ /dev/null @@ -1,107 +0,0 @@ -import xarray as xr -import pandas as pd -import numpy as np -import datetime -import logging - - -def generate_geo_dict(domain): - ncols = domain.get("n_cols") - nrows = domain.get("n_rows") - psize = domain.get("p_size") - start_x = domain.get("start_x") - start_y = domain.get("start_y") - x = [start_x + i * psize for i in range(ncols)] - y = [start_y + i * psize for i in range(nrows)] - - out_geo = {} - out_geo["x"] = x - out_geo["y"] = y - out_geo["xpixelsize"] = psize - out_geo["ypixelsize"] = psize - out_geo["x1"] = start_x - out_geo["y1"] = start_y - out_geo["x2"] = start_x + (ncols - 1) * psize - out_geo["y2"] = start_y + (nrows - 1) * psize - out_geo["projection"] = domain["projection"]["epsg"] - out_geo["cartesian_unit"] = "m" - out_geo["yorigin"] = "lower" - out_geo["unit"] = "mm/h" - out_geo["threshold"] = 0 - out_geo["transform"] = None - - return out_geo - - -def generate_geo_dict_xy(x: np.ndarray, y: np.ndarray, epsg: str): - n_cols = x.size - n_rows = y.size - - out_geo = {} - out_geo["xpixelsize"] = (x[-1] - x[0]) / (n_cols - 1) - out_geo["ypixelsize"] = (y[-1] - y[0]) / (n_rows - 1) - out_geo["x1"] = x[0] - out_geo["x2"] = x[-1] - out_geo["y1"] = y[0] - out_geo["y2"] = y[-1] - out_geo["projection"] = epsg - out_geo["cartesian_unit"] = "m" - out_geo["yorigin"] = "lower" - out_geo["unit"] = "mm/h" - out_geo["threshold"] = 0 - out_geo["transform"] = None - - return out_geo - - -def read_qpe_netcdf(file_path): - """ - Read WRNZ QPE NetCDF file and return xarray Dataset of rain rate with: - - 'rain' variable in [time, yc, xc] order - - time as timezone-aware UTC datetimes - - EPSG:2193 (NZTM2000) projection info added using CF conventions - - Return None on error reading the file - - Assumes that the input file is rain rate in [t,y,x] order - """ - - try: - ds = xr.open_dataset(file_path, decode_cf=True, 
mask_and_scale=True) - ds.load() - - # Make the times timezone-aware UTC - time_values = ds["time"].values.astype("datetime64[ns]") - time_utc = pd.DatetimeIndex(time_values, tz=datetime.UTC) - ds["time"] = ("time", time_utc) - - # Rename - ds = ds.rename({"rainfall": "rain"}) - - # Define CF-compliant grid mapping for EPSG:2193 - crs = xr.DataArray( - 0, - attrs={ - "grid_mapping_name": "transverse_mercator", - "scale_factor_at_central_meridian": 0.9996, - "longitude_of_central_meridian": 173.0, - "latitude_of_projection_origin": 0.0, - "false_easting": 1600000.0, - "false_northing": 10000000.0, - "semi_major_axis": 6378137.0, - "inverse_flattening": 298.257222101, - "spatial_ref": "EPSG:2193", - }, - name="NZTM2000", - ) - - ds["NZTM2000"] = crs - ds["rain"].attrs["grid_mapping"] = "NZTM2000" - - ds = ds[["rain", "NZTM2000"]] - ds = ds.assign_coords(time=ds["time"], yc=ds["y"], xc=ds["x"]) - - return ds - - except (ValueError, OverflowError, TypeError) as e: - logging.warning(f"Failed to read {file_path}: {e}") - return None From c8b5d2e826dee40e4ce6ee25bc08ce736cf13ca1 Mon Sep 17 00:00:00 2001 From: Alan Seed Date: Tue, 9 Sep 2025 23:07:46 +1000 Subject: [PATCH 09/12] update code --- pysteps/param/shared_utils.py | 187 ++++++++++++++++------------------ 1 file changed, 90 insertions(+), 97 deletions(-) diff --git a/pysteps/param/shared_utils.py b/pysteps/param/shared_utils.py index 4da386e90..f37767045 100644 --- a/pysteps/param/shared_utils.py +++ b/pysteps/param/shared_utils.py @@ -20,7 +20,7 @@ from .stochastic_generator import gen_stoch_field, normalize_db_field from .rainfield_stats import correlation_length from .rainfield_stats import power_spectrum_1D -from .cascade_utils import lagr_auto_cor +from .cascade_utils import lagr_auto_cor, calculate_wavelengths def update_field( @@ -28,37 +28,31 @@ def update_field( oflow: np.ndarray, params: StepsParameters, bp_filter: dict, - config: dict, - dom: dict, + kmperpixel: float, + scale_break_km: float, + db_threshold: float, + zerovalue: float, ) -> np.ndarray: """ - Update a rainfall field using the parametric STEPS algorithm. - Assumes that the cascades list has the correct number of valid cascades + Generate a field conditioned on the cascades that are passed into the function Args: - cascades (list): List of 1 or 2 cascades for initial conditions - oflow(np.ndarray): Optical flow array - params (StepsParameters): Parameters for the update. 
- bp_filter: Bandpass filter dictionary returned by pysteps.cascade.bandpass_filters.filter_gaussian - config: The configuration dictionary - dom: the domain dictionary + cascades (list): cascades in reverse time order [t-1,t-2] + oflow (np.ndarray): _description_ + params (StepsParameters): _description_ + bp_filter (dict): _description_ + kmperpixel (float): _description_ + scale_break_km (float): _description_ + db_threshold (float): _description_ + zerovalue (float): _description_ Returns: - np.ndarray: Updated rainfall field in decibels (dB) of rain intensity + np.ndarray: _description_ """ - - ar_order = config["ar_order"] - n_levels = config["n_cascade_levels"] - n_rows = dom["n_rows"] - n_cols = dom["n_cols"] - - scale_break_km = config["scale_break"] - kmperpixel = config["kmperpixel"] - - rain_threshold = config["precip_threshold"] - db_threshold = 10 * np.log10(rain_threshold) - transformer = DBTransformer(rain_threshold) - zerovalue = transformer.zerovalue + ar_order = 2 # Assume AR(2) for now (It is the best option) + n_levels = cascades[0]["cascade_levels"].shape[0] + n_rows = cascades[0]["cascade_levels"].shape[1] + n_cols = cascades[0]["cascade_levels"].shape[2] # Set up the AR(2) parameters phi = np.zeros((n_levels, ar_order + 1)) @@ -129,10 +123,17 @@ def update_field( return norm_field -def zero_state(config, domain): - n_cascade_levels = config["n_cascade_levels"] - n_rows = domain["n_rows"] - n_cols = domain["n_cols"] +def zero_state(n_rows: int, n_cols: int, n_levels: int): + """_summary_ + Generate a dictionary with the data that is needed for the steps nowcast + Args: + n_rows (int): _description_ + n_cols (int): _description_ + n_levels (int): _description_ + + Returns: + _type_: _description_ + """ metadata_dict = { "transform": None, "threshold": None, @@ -142,9 +143,9 @@ def zero_state(config, domain): "wetted_area_ratio": float(0), } cascade_dict = { - "cascade_levels": np.zeros((n_cascade_levels, n_rows, n_cols)), - "means": np.zeros(n_cascade_levels), - "stds": np.zeros(n_cascade_levels), + "cascade_levels": np.zeros((n_levels, n_rows, n_cols)), + "means": np.zeros(n_levels), + "stds": np.zeros(n_levels), "domain": "spatial", "normalized": True, } @@ -153,7 +154,16 @@ def zero_state(config, domain): return state -def is_zero_state(state, tol=1e-6): +def is_zero_state(state, tol=1e-6) -> bool: + """ + Determine if the state is empty + Args: + state (_type_): _description_ + tol (_type_, optional): _description_. Defaults to 1e-6. + + Returns: + _type_: True or False + """ return abs(state["metadata"]["mean"]) < tol @@ -167,7 +177,7 @@ def is_zero_state(state, tol=1e-6): # corl_zero 1074.976508 188.058276 23.489147 -def qc_params(ens_df, config): +def qc_params(ens_df: pd.DataFrame, config: dict, domain: dict) -> pd.DataFrame: """ Apply QC to the 'param' column in the ensemble DataFrame. The DataFrame is assumed to have 'valid_time' as the index. 
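A minimal sketch of the smooth-then-clip step that qc_params applies to each parameter
series (the input series and smoothing level are illustrative assumptions; the bounds
are the first entries of the var_lower/var_upper lists in the hunk below):

    import numpy as np
    from statsmodels.tsa.holtwinters import SimpleExpSmoothing

    raw = [4.2, 4.6, 9.9, 4.4, 4.5]   # e.g. a nonzero_mean_db series with one spike
    lower, upper = 2.81, 9.50         # climatological bounds for this parameter

    # Exponentially smooth the series, then clip to the allowed range
    fit = SimpleExpSmoothing(raw, initialization_method="estimated").fit(
        smoothing_level=0.1, optimized=False
    )
    qc = np.clip(fit.fittedvalues, lower, upper)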
@@ -182,9 +192,10 @@ def qc_params(ens_df, config): "rain_fraction", "beta_1", "beta_2", + "corl_zero", ] - var_lower = [2.81, 1.30, 0.0, -2.73, -4.01] - var_upper = [9.50, 5.00, 1.0, -2.05, -2.32] + var_lower = [2.81, 1.30, 0.0, -2.73, -4.01, 60] + var_upper = [9.50, 5.00, 1.0, -2.05, -2.32, 1000] qc_df = ens_df.copy(deep=True) qc_dict = {} @@ -205,24 +216,6 @@ def qc_params(ens_df, config): ) qc_dict[var] = np.clip(model.fittedvalues, var_lower[iv], var_upper[iv]) - # Extract correlation length thresholds from config - corl_pvals = config["dynamic_scaling"]["cor_len_pvals"] - corl_min = min(corl_pvals) - corl_max = max(corl_pvals) - corl_def = corl_pvals[1] # median - - # Prepare and smooth corl_zero - corl_list = [] - for idx in qc_df.index: - corl = qc_df.at[idx, "param"].get("corl_zero", corl_def) - corl = corl_def if corl is None else max(corl_min, min(corl, corl_max)) - corl_list.append(corl) - - model = SimpleExpSmoothing(corl_list, initialization_method="estimated").fit( - smoothing_level=0.1, optimized=False - ) - qc_dict["corl_zero"] = model.fittedvalues - # Assign smoothed parameters and compute lags for i, idx in enumerate(qc_df.index): param = copy.deepcopy(qc_df.at[idx, "param"]) @@ -236,8 +229,8 @@ def qc_params(ens_df, config): setattr(param, var, qc_dict[var][i]) param.corl_zero = qc_dict["corl_zero"][i] - # Compute lag-1 and lag-2 for this correlation length - lags, _ = calc_auto_corls(config, param.corl_zero) + # Compute lag-1 and lag-2 for each level in the cascade + lags, _ = calc_auto_cors(config, domain, param.corl_zero) param.lag_1 = list(lags[:, 0]) param.lag_2 = list(lags[:, 1]) @@ -266,7 +259,8 @@ def blend_param(qpe_params, nwp_params, param_names, weight): def blend_parameters( - config: dict[str, object], + config: dict, + domain: dict, blend_base_time: datetime.datetime, nwp_param_df: pd.DataFrame, rad_param: StepsParameters, @@ -312,7 +306,7 @@ def default_weight_fn(lag_sec: float) -> float: updated = blend_param(rad_param, clean_original, blended_param_names, weight) # Compute lag-1 and lag-2 for this correlation length - lags, _ = calc_auto_corls(config, updated.corl_zero) + lags = calc_auto_cors(config, domain, updated.corl_zero) updated.lag_1 = list(lags[:, 0]) updated.lag_2 = list(lags[:, 1]) @@ -384,42 +378,6 @@ def fill_param_gaps( return pd.DataFrame(records) -def calc_auto_corls(config: dict, T_ref: float) -> tuple[np.ndarray, np.ndarray]: - """ - Compute lag-1 and lag-2 autocorrelations for each cascade level using a power-law model. - - Args: - config (dict): Configuration dictionary with 'pysteps.timestep' (in seconds) - and 'dynamic_scaling' parameters. - T_ref (float): Reference correlation length T(t, L) at the largest scale (in minutes). - - Returns: - np.ndarray: Array of shape (n_levels, 2) with [lag1, lag2] for each level. 
- np.ndarray: Array of corelation lengths per level - """ - dt_seconds = config["timestep"] - dt_mins = dt_seconds / 60.0 - - ds_config = config.get("dynamic_scaling", {}) - scales = ds_config["central_wave_lengths"] - ht = ds_config["space_time_exponent"] - a = ds_config["lag2_constants"] - b = ds_config["lag2_exponents"] - - L = scales[0] - T_levels = [T_ref * (l / L) ** ht for l in scales] - - lags = np.empty((len(scales), 2), dtype=np.float32) - for ia, T_l in enumerate(T_levels): - pl_lag1 = np.exp(-dt_mins / T_l) - pl_lag2 = a[ia] * (pl_lag1 ** b[ia]) - lags[ia, 0] = pl_lag1 - lags[ia, 1] = pl_lag2 - - levels = np.array(T_levels) - return lags, levels - - def fit_auto_cors( clen: float, alpha: float, @@ -521,7 +479,7 @@ def _obj(l1: float) -> float: return lag1, lag2 -def calc_corls(scales, czero, ht): +def calc_cor_lengths(scales: np.ndarray, czero: float, ht: float) -> list[float]: # Power law function for correlation length corls = [czero] lzero = scales[0] @@ -531,14 +489,49 @@ def calc_corls(scales, czero, ht): return corls +def calc_auto_cors(config: dict, domain: dict, czero: float) -> np.ndarray: + timestep = config["timestep"] + dt = timestep / 60 + n_levels = config["n_cascade_levels"] + ht = config.get("ht", 0.96) + alpha = config.get("alpha", 2.83) + n_cols = domain["n_cols"] + n_rows = domain["n_rows"] + p_size = domain["p_size"] + domain_size = max(n_rows, n_cols) + d = 1000 / p_size + scales = calculate_wavelengths(n_levels, domain_size, d) + cor_lens = calc_cor_lengths(scales, czero, ht) + lags = np.zeros((n_levels, 2)) + for ilev in range(n_levels): + r1, r2 = fit_auto_cors(cor_lens[ilev], alpha, dt) # type: ignore + lags[ilev, 0] = r1 + lags[ilev, 1] = r2 + return lags + + def calculate_parameters( db_field: np.ndarray, - cascades: dict, + cascades: list[dict], oflow: np.ndarray, scale_break: float, zero_value: float, dt: int, -): +) -> StepsParameters: + """ + Calculate the steps parameters + + Args: + db_field (np.ndarray): + cascades (list[dict]): list of cascade dictionaries in order t-2, t-1, t0 + oflow (np.ndarray): Optical flow from t-1 to t0 + scale_break (float): Location of the power spectum scale break in pixels + zero_value (float): zero value for the dBR transformation + dt (int): time step in minutes + + Returns: + StepsParameters: Dataclass with the steps parameters + """ p_dict = {} # Probability distribution moments From da105e88db24129277d781053adf91a38e7d1436 Mon Sep 17 00:00:00 2001 From: Alan Seed Date: Tue, 9 Sep 2025 23:08:04 +1000 Subject: [PATCH 10/12] update code --- pysteps/param/steps_params.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pysteps/param/steps_params.py b/pysteps/param/steps_params.py index aaa9a56c8..fcdfcde70 100644 --- a/pysteps/param/steps_params.py +++ b/pysteps/param/steps_params.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field import datetime +# These are representative values from a 250 km in Auckland N.Z # 95% 50% 5% # nonzero_mean_db 6.883147 4.590082 2.815397 # nonzero_stdev_db 3.793680 2.489131 1.298552 From 90fe9d9e8449889b4b429ae6ae0bf7a262ac5c70 Mon Sep 17 00:00:00 2001 From: Alan Seed Date: Wed, 10 Sep 2025 00:45:52 +1000 Subject: [PATCH 11/12] moved transformer --- pysteps/param/__init__.py | 1 - pysteps/param/broken_line.py | 124 ----------------------- pysteps/param/transformer.py | 187 ----------------------------------- 3 files changed, 312 deletions(-) delete mode 100644 pysteps/param/broken_line.py delete mode 100644 pysteps/param/transformer.py diff --git a/pysteps/param/__init__.py 
b/pysteps/param/__init__.py index 5eb5d7412..7cb6c050b 100644 --- a/pysteps/param/__init__.py +++ b/pysteps/param/__init__.py @@ -19,4 +19,3 @@ fit_auto_cors, calculate_parameters, ) -from .transformer import DBTransformer diff --git a/pysteps/param/broken_line.py b/pysteps/param/broken_line.py deleted file mode 100644 index 101f3ddd2..000000000 --- a/pysteps/param/broken_line.py +++ /dev/null @@ -1,124 +0,0 @@ - -import numpy as np -from typing import Optional - - -def broken_line(rain_mean: float, rain_std: float, time_step: int, duration: int, - h: Optional[float] = 0.60, q: Optional[float] = 0.85, - a_zero_min: Optional[float] = 1500, transform: Optional[bool] = True): - """ - Generate a time series of rainfall using the broken line model. - Based on Seed et al. (2000), WRR. - - Args: - rain_mean (float): Mean of time series (must be > 0) - rain_std (float): Standard deviation of time series (must be > 0) - time_step (int): Time step in minutes (must be > 0) - duration (int): Duration of time series in minutes (must be > time_step) - h (float): Scaling exponent (0 < h < 1) - q (float): Scale change ratio between lines (0 < q < 1) - a_zero_min (float): Maximum time scale in minutes (must be > 0) - transform (bool): Use log transformation to generate the time series - - Returns: - np.ndarray: Rainfall time series of specified length, or None on error - """ - - # Validate input parameters - if not isinstance(rain_mean, (float, int)) or rain_mean <= 0: - print("Error: rain_mean must be a positive number.") - return None - if not isinstance(rain_std, (float, int)) or rain_std <= 0: - print("Error: rain_std must be a positive number.") - return None - if not isinstance(time_step, int) or time_step <= 0: - print("Error: time_step must be a positive integer.") - return None - if not isinstance(duration, int) or duration <= time_step: - print("Error: duration must be an integer greater than time_step.") - return None - if not isinstance(h, (float, int)) or not (0 < h < 1): - print("Error: h must be a float in the range (0,1).") - return None - if not isinstance(q, (float, int)) or not (0 < q < 1): - print("Error: q must be a float in the range (0,1).") - return None - if not isinstance(a_zero_min, (float, int)) or a_zero_min <= 0: - print("Error: a_zero_min must be a positive number.") - return None - - # Number of time steps to generate - length = duration // time_step # Ensure integer division - - # Calculate the lognormal mean and variance - if transform: - ratio = rain_std / rain_mean - bl_mean = np.log(rain_mean) - 0.5 * np.log(ratio**2 + 1) - bl_var = np.log(ratio**2 + 1) - else: - bl_mean = rain_mean - bl_var = rain_std ** 2.0 - - # Compute number of broken lines - a_zero = a_zero_min / time_step - N = max(1, int(np.log(1.0 / a_zero) / np.log(q)) + 1) # Prevents N=0 - - # Compute variance at the outermost scale - var_zero = bl_var * (1 - q**h) / (1 - q**(N * h)) - - # Initialize the time series with mean - model = np.full(length, bl_mean) - - # Add broken lines at different scales - for p in range(N): - break_step = a_zero * q**p - line_stdev = np.sqrt(var_zero * q**(p * h)) - line = make_line(line_stdev, break_step, length) - model += line - - # Transform back to rainfall space if needed - if transform: - rain = np.exp(model) - return rain - else: - return model - - -def make_line(std_dev, break_step, length): - """ - Generate a piecewise linear process with random breakpoints. - - Args: - std_dev (float): Standard deviation for generating y-values. 
- break_step (float): Distance between breakpoints. - length (int): Length of the output array. - - Returns: - np.ndarray: Interpolated line of given length. - """ - - # Generate random breakpoints - rng = np.random.default_rng(None) - - if break_step < 1: - y = rng.normal(0, std_dev, length) # Scaled correctly - return y - - # Number of breakpoints - n_points = 3 + int(length / break_step) - y = rng.normal(0, 1.5 * std_dev, n_points) # Scaled correctly - - # Generate x-coordinates with random offset - offset = rng.uniform(-break_step, 0) - x = [offset + break_step*ia for ia in range(n_points)] - - # Interpolate onto full time series - x_out = np.arange(length) - line = np.interp(x_out, x, y) - - # Normalize the standard deviation - line_std = np.std(line) - if line_std > 0: - line = (line - np.mean(line)) * (std_dev / line_std) - - return line diff --git a/pysteps/param/transformer.py b/pysteps/param/transformer.py deleted file mode 100644 index 7e5325c02..000000000 --- a/pysteps/param/transformer.py +++ /dev/null @@ -1,187 +0,0 @@ -import numpy as np -import scipy.stats as scipy_stats -from scipy.interpolate import interp1d -from typing import Optional - - -class BaseTransformer: - def __init__(self, threshold: float = 0.5, zerovalue: Optional[float] = None): - self.threshold = threshold - self.zerovalue = zerovalue - self.metadata = {} - - def transform(self, R: np.ndarray) -> np.ndarray: - raise NotImplementedError - - def inverse_transform(self, R: np.ndarray) -> np.ndarray: - raise NotImplementedError - - def get_metadata(self) -> dict: - return self.metadata.copy() - - -class DBTransformer(BaseTransformer): - """ - DBTransformer applies a thresholded dB transform to rain rate fields. - - Parameters: - threshold (float): Rain rate threshold (in mm/h). Values below this are set to `zerovalue` in dB. - zerovalue (Optional[float]): Value in dB space to assign below-threshold pixels. 
If None, defaults to log10(threshold) - 0.1 - """ - - def __init__(self, threshold: float = 0.5, zerovalue: Optional[float] = None): - super().__init__(threshold, zerovalue) - threshold_db = 10.0 * np.log10(self.threshold) - - if self.zerovalue is None: - self.zerovalue = threshold_db - 0.1 - - self.metadata = { - "transform": "dB", - "threshold": self.threshold, # stored in mm/h - "zerovalue": self.zerovalue, # stored in dB - } - - def transform(self, R: np.ndarray) -> np.ndarray: - R = R.copy() - mask = R < self.threshold - R[~mask] = 10.0 * np.log10(R[~mask]) - R[mask] = self.zerovalue - return R - - def inverse_transform(self, R: np.ndarray) -> np.ndarray: - R = R.copy() - R = 10.0 ** (R / 10.0) - R[R < self.threshold] = 0 - return R - - -class BoxCoxTransformer(BaseTransformer): - def __init__(self, Lambda: float = 0.0, **kwargs): - super().__init__(**kwargs) - self.Lambda = Lambda - - def transform(self, R: np.ndarray) -> np.ndarray: - R = R.copy() - mask = R < self.threshold - - if self.Lambda == 0.0: - R[~mask] = np.log(R[~mask]) - tval = np.log(self.threshold) - else: - R[~mask] = (R[~mask] ** self.Lambda - 1) / self.Lambda - tval = (self.threshold**self.Lambda - 1) / self.Lambda - - if self.zerovalue is None: - self.zerovalue = tval - 1 - - R[mask] = self.zerovalue - - self.metadata = { - "transform": "BoxCox", - "lambda": self.Lambda, - "threshold": tval, - "zerovalue": self.zerovalue, - } - return R - - def inverse_transform(self, R: np.ndarray) -> np.ndarray: - R = R.copy() - if self.Lambda == 0.0: - R = np.exp(R) - else: - R = np.exp(np.log(self.Lambda * R + 1) / self.Lambda) - - threshold_inv = ( - np.exp(np.log(self.Lambda * self.metadata["threshold"] + 1) / self.Lambda) - if self.Lambda != 0.0 - else np.exp(self.metadata["threshold"]) - ) - - R[R < threshold_inv] = self.metadata["zerovalue"] - self.metadata["transform"] = None - return R - - -class NQTransformer(BaseTransformer): - def __init__(self, a: float = 0.0, **kwargs): - super().__init__(**kwargs) - self.a = a - self._inverse_interp = None - - def transform(self, R: np.ndarray) -> np.ndarray: - R = R.copy() - shape = R.shape - R = R.ravel() - mask = ~np.isnan(R) - R_ = R[mask] - - n = R_.size - Rpp = (np.arange(n) + 1 - self.a) / (n + 1 - 2 * self.a) - Rqn = scipy_stats.norm.ppf(Rpp) - R_sorted = R_[np.argsort(R_)] - R_trans = np.interp(R_, R_sorted, Rqn) - - self.zerovalue = np.min(R_) - R_trans[R_ == self.zerovalue] = 0 - - self._inverse_interp = interp1d( - Rqn, - R_sorted, - bounds_error=False, - fill_value=(float(R_sorted.min()), float(R_sorted.max())), # type: ignore - ) - - R[mask] = R_trans - R = R.reshape(shape) - - self.metadata = { - "transform": "NQT", - "threshold": R_trans[R_trans > 0].min(), - "zerovalue": 0, - } - return R - - def inverse_transform(self, R: np.ndarray) -> np.ndarray: - if self._inverse_interp is None: - raise RuntimeError("Must call transform() before inverse_transform()") - - R = R.copy() - shape = R.shape - R = R.ravel() - mask = ~np.isnan(R) - R[mask] = self._inverse_interp(R[mask]) - R = R.reshape(shape) - - self.metadata["transform"] = None - return R - - -class SqrtTransformer(BaseTransformer): - def transform(self, R: np.ndarray) -> np.ndarray: - R = np.sqrt(R) - self.metadata = { - "transform": "sqrt", - "threshold": np.sqrt(self.threshold), - "zerovalue": np.sqrt(self.zerovalue) if self.zerovalue else 0.0, - } - return R - - def inverse_transform(self, R: np.ndarray) -> np.ndarray: - R = R**2 - self.metadata["transform"] = None - return R - - -def get_transformer(name: str, 
**kwargs) -> BaseTransformer: - name = name.lower() - if name == "boxcox": - return BoxCoxTransformer(**kwargs) - elif name == "db": - return DBTransformer(**kwargs) - elif name == "nqt": - return NQTransformer(**kwargs) - elif name == "sqrt": - return SqrtTransformer(**kwargs) - else: - raise ValueError(f"Unknown transformer type: {name}") From c1c884cdab91fb729542dcf5d674949ab6606734 Mon Sep 17 00:00:00 2001 From: Alan Seed Date: Wed, 10 Sep 2025 17:28:47 +1000 Subject: [PATCH 12/12] remove param directory --- pysteps/param/__init__.py | 21 - pysteps/param/cascade_utils.py | 88 ---- pysteps/param/rainfield_stats.py | 509 ----------------------- pysteps/param/shared_utils.py | 578 -------------------------- pysteps/param/steps_params.py | 125 ------ pysteps/param/stochastic_generator.py | 225 ---------- 6 files changed, 1546 deletions(-) delete mode 100644 pysteps/param/__init__.py delete mode 100644 pysteps/param/cascade_utils.py delete mode 100644 pysteps/param/rainfield_stats.py delete mode 100644 pysteps/param/shared_utils.py delete mode 100644 pysteps/param/steps_params.py delete mode 100644 pysteps/param/stochastic_generator.py diff --git a/pysteps/param/__init__.py b/pysteps/param/__init__.py deleted file mode 100644 index 7cb6c050b..000000000 --- a/pysteps/param/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -from .steps_params import StepsParameters -from .rainfield_stats import ( - RainfieldStats, - compute_field_parameters, - compute_field_stats, - power_spectrum_1D, - correlation_length, - power_law_acor, -) -from .stochastic_generator import gen_stoch_field, normalize_db_field, pl_filter -from .cascade_utils import calculate_wavelengths, lagr_auto_cor -from .shared_utils import ( - qc_params, - update_field, - blend_parameters, - zero_state, - is_zero_state, - calc_auto_cors, - fit_auto_cors, - calculate_parameters, -) diff --git a/pysteps/param/cascade_utils.py b/pysteps/param/cascade_utils.py deleted file mode 100644 index 97e41104b..000000000 --- a/pysteps/param/cascade_utils.py +++ /dev/null @@ -1,88 +0,0 @@ -import numpy as np -from pysteps import extrapolation - - -def calculate_wavelengths(n_levels: int, domain_size: float, d: float = 1.0): - """ - Compute the central wavelengths (in km) for each cascade level. - - Parameters - ---------- - n_levels : int - Number of cascade levels. - domain_size : int or float - The larger of the two spatial dimensions of the domain in pixels. - d : float - Sample frequency in pixels per km. Default is 1. - - Returns - ------- - wavelengths_km : np.ndarray - Central wavelengths in km for each cascade level (length = n_levels). - """ - # Compute q - q = pow(0.5 * domain_size, 1.0 / n_levels) - - # Compute central wavenumbers (in grid units) - r = [(pow(q, k - 1), pow(q, k)) for k in range(1, n_levels + 1)] - central_wavenumbers = np.array([0.5 * (r0 + r1) for r0, r1 in r]) - - # Convert to frequency - central_freqs = central_wavenumbers / domain_size - central_freqs[0] = 1.0 / domain_size - central_freqs[-1] = 0.5 # Nyquist limit - - # Convert wavelength to km, d is pixels per km - central_freqs = central_freqs * d - central_wavelengths_km = 1.0 / central_freqs - return central_wavelengths_km - - -def lagr_auto_cor(data: np.ndarray, oflow: np.ndarray): - """ - Generate the Lagrangian auto correlations for STEPS cascades. - - Args: - data (np.ndarray): [T, L, M, N] where: - - T = ar_order + 1 (number of time steps) - - L = number of cascade levels - - M, N = spatial dimensions. - oflow (np.ndarray): [2, M, N] Optical flow vectors. 
- - Returns: - np.ndarray: Autocorrelation coefficients of shape (L, ar_order). - """ - ar_order = 2 - if data.shape[0] < (ar_order + 1): - raise ValueError( - f"Insufficient time steps. Expected at least {ar_order + 1}, got {data.shape[0]}." - ) - - n_cascade_levels = data.shape[1] - extrapolation_method = extrapolation.get_method("semilagrangian") - - autocorrelation_coefficients = np.full((n_cascade_levels, ar_order), np.nan) - - for level in range(n_cascade_levels): - lag_1 = extrapolation_method(data[-2, level], oflow, 1)[0] - lag_1 = np.where(np.isfinite(lag_1), lag_1, 0) - - data_t = np.where(np.isfinite(data[-1, level]), data[-1, level], 0) - if np.std(lag_1) > 1e-1 and np.std(data_t) > 1e-1: - autocorrelation_coefficients[level, 0] = np.corrcoef( - lag_1.flatten(), data_t.flatten() - )[0, 1] - - if ar_order == 2: - lag_2 = extrapolation_method(data[-3, level], oflow, 1)[0] - lag_2 = np.where(np.isfinite(lag_2), lag_2, 0) - - lag_1 = extrapolation_method(lag_2, oflow, 1)[0] - lag_1 = np.where(np.isfinite(lag_1), lag_1, 0) - - if np.std(lag_1) > 1e-1 and np.std(data_t) > 1e-1: - autocorrelation_coefficients[level, 1] = np.corrcoef( - lag_1.flatten(), data_t.flatten() - )[0, 1] - - return autocorrelation_coefficients diff --git a/pysteps/param/rainfield_stats.py b/pysteps/param/rainfield_stats.py deleted file mode 100644 index a8486b703..000000000 --- a/pysteps/param/rainfield_stats.py +++ /dev/null @@ -1,509 +0,0 @@ -# Contains: RainfieldStats (dataclass, from_dict, to_dict), compute_field_parameters, compute_field_stats -""" -Functions to calculate rainfield statistics -""" -from typing import Optional, Tuple, Dict, List, Any -import datetime -from dataclasses import dataclass -import xarray as xr -from scipy.optimize import curve_fit -import numpy as np - -MAX_RAIN_RATE = 250 -N_BINS = 200 - - -@dataclass -class RainfieldStats: - domain: Optional[str] = None - product: Optional[str] = None - valid_time: Optional[datetime.datetime] = None - base_time: Optional[datetime.datetime] = None - ensemble: Optional[int] = None - filename: Optional[str] = None - - transform: Optional[str] = None - zerovalue: Optional[float] = None - threshold: Optional[float] = None - kmperpixel: Optional[float] = None - - mean_db: Optional[float] = None - stdev_db: Optional[float] = None - nonzero_mean_db: Optional[float] = None - nonzero_stdev_db: Optional[float] = None - - mean_rain: Optional[float] = None - stdev_rain: Optional[float] = None - rain_fraction: Optional[float] = None - nonzero_mean_rain: Optional[float] = None - nonzero_stdev_rain: Optional[float] = None - - psd: Optional[List[float]] = None - psd_bins: Optional[List[float]] = None - c1: Optional[float] = None - c2: Optional[float] = None - scale_break: Optional[float] = None - beta_1: Optional[float] = None - beta_2: Optional[float] = None - corl_zero: Optional[float] = None - - cdf: Optional[List[float]] = None - cdf_bins: Optional[List[float]] = None - - cascade_stds: Optional[List[float]] = None - cascade_means: Optional[List[float]] = None - cascade_lag1: Optional[List[float]] = None - cascade_lag2: Optional[List[float]] = None - cascade_corl: Optional[List[float]] = None - - def get(self, key: str, default: Any = None) -> Any: - """Mimic dict.get() for selected attributes.""" - return getattr(self, key, default) - - def calc_corl(self): - """Populate the correlation lengths using lag1 and lag2 values.""" - - # Make sure that we have defined auto-correlations - if self.cascade_lag1 is None or self.cascade_lag2 is None: - 
self.cascade_corl = None - return - - # We have defined auto-correlations so set up the correlation length array - n_levels = len(self.cascade_lag1) - self.cascade_corl = [np.nan] * n_levels - for ilev in range(n_levels): - lag1 = self.cascade_lag1[ilev] - lag2 = self.cascade_lag2[ilev] - self.cascade_corl[ilev] = correlation_length(lag1, lag2) - self.corl_zero = self.cascade_corl[0] - - def calc_acor(self, config) -> None: - T_ref = self.corl_zero - if T_ref is None or np.isnan(T_ref): - T_ref = config["dynamic_scaling"]["cor_len_pvals"][1] - - acor, corl = power_law_acor(config, T_ref) - self.cascade_corl = [float(x) for x in corl] - self.cascade_lag1 = [float(x) for x in acor[:, 0]] - self.cascade_lag2 = [float(x) for x in acor[:, 1]] - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "RainfieldStats": - dbr = data.get("dbr_stats", {}) - rain = data.get("rain_stats", {}) - pspec = data.get("power_spectrum", {}) - model = pspec.get("model", {}) if pspec else {} - cdf_data = data.get("cdf", {}) - cascade = data.get("cascade", {}) - meta = data.get("metadata", {}) - - kwargs = { - "product": meta.get("product"), - "valid_time": meta.get("valid_time"), - "base_time": meta.get("base_time"), - "ensemble": meta.get("ensemble"), - "filename": meta.get("filename"), - "kmperpixel": meta.get("kmperpixel"), - "transform": dbr.get("transform"), - "zerovalue": dbr.get("zerovalue"), - "threshold": dbr.get("threshold"), - "nonzero_mean_db": dbr.get("nonzero_mean"), - "nonzero_stdev_db": dbr.get("nonzero_stdev"), - "rain_fraction": dbr.get("nonzero_fraction"), - "mean_db": dbr.get("mean"), - "stdev_db": dbr.get("stdev"), - "nonzero_mean_rain": rain.get("nonzero_mean"), - "nonzero_stdev_rain": rain.get("nonzero_stdev"), - "mean_rain": rain.get("mean"), - "stdev_rain": rain.get("stdev"), - "psd": pspec.get("psd"), - "psd_bins": pspec.get("psd_bins"), - "c1": model.get("c1"), - "c2": model.get("c2"), - "scale_break": model.get("scale_break"), - "beta_1": model.get("beta_1"), - "beta_2": model.get("beta_2"), - "cdf": cdf_data.get("cdf"), - "cdf_bins": cdf_data.get("cdf_bins"), - "corl_zero": cascade.get("corl_zero"), - "cascade_stds": cascade.get("stds"), - "cascade_means": cascade.get("means"), - "cascade_lag1": cascade.get("lag1"), - "cascade_lag2": cascade.get("lag2"), - } - - # Make sure the times are UTC - ttime = kwargs["valid_time"] - if ttime is not None and ttime.tzinfo is None: - ttime = ttime.replace(tzinfo=datetime.timezone.utc) - kwargs["valid_time"] = ttime - - ttime = kwargs["base_time"] - if ttime is not None and ttime.tzinfo is None: - ttime = ttime.replace(tzinfo=datetime.timezone.utc) - kwargs["base_time"] = ttime - - # Add cascade_corl explicitly, since it's constructed dynamically - lag1_list = cascade.get("lag1", []) - kwargs["cascade_corl"] = [np.nan] * len(lag1_list) if lag1_list else None - - return cls(**kwargs) - - def to_dict(self) -> Dict[str, Any]: - return { - "dbr_stats": { - "transform": self.transform, - "zerovalue": self.zerovalue, - "threshold": self.threshold, - "nonzero_mean": self.nonzero_mean_db, - "nonzero_stdev": self.nonzero_stdev_db, - "nonzero_fraction": self.rain_fraction, - "mean": self.mean_db, - "stdev": self.stdev_db, - }, - "rain_stats": { - "nonzero_mean": self.nonzero_mean_rain, - "nonzero_stdev": self.nonzero_stdev_rain, - "nonzero_fraction": self.rain_fraction, # assume same as dbr_stats - "mean": self.mean_rain, - "stdev": self.stdev_rain, - "transform": None, - "zerovalue": 0, - "threshold": 0.1, - }, - "power_spectrum": { - "psd": 
self.psd, - "psd_bins": self.psd_bins, - "model": ( - { - "beta_1": self.beta_1, - "beta_2": self.beta_2, - "c1": self.c1, - "c2": self.c2, - "scale_break": self.scale_break, - } - if any( - x is not None - for x in [ - self.beta_1, - self.beta_2, - self.c1, - self.c2, - self.scale_break, - ] - ) - else None - ), - }, - "cdf": { - "cdf": self.cdf, - "cdf_bins": self.cdf_bins, - }, - "cascade": ( - { - "corl_zero": self.corl_zero, - "stds": self.cascade_stds, - "means": self.cascade_means, - "lag1": self.cascade_lag1, - "lag2": self.cascade_lag2, - "corl": self.cascade_corl, - } - if self.cascade_stds is not None - else None - ), - "metadata": { - "domain": self.domain, - "product": self.product, - "valid_time": self.valid_time, - "base_time": self.base_time, - "ensemble": self.ensemble, - "filename": self.filename, - "kmperpixel": self.kmperpixel, - }, - } - - -def compute_field_parameters( - db_data: np.ndarray, db_metadata: dict, scalebreak: Optional[float] = None -): - """ - Compute STEPS parameters for the dB transformed rainfall field - - Args: - db_data (np.ndarray): 2D field of dB-transformed rain. - db_metadata (dict): pysteps metadata dictionary. - - Returns: - dict: Dictionary containing STEPS parameters. - """ - - ps_dataset, ps_model = power_spectrum_1D(db_data, scalebreak) - if ps_dataset is not None: - power_spectrum = { - "psd": ps_dataset.psd.values.tolist(), - "psd_bins": ps_dataset.psd_bins.values.tolist(), - "model": ps_model, - } - else: - power_spectrum = {} - - # Compute cumulative probability distribution - cdf_dataset = prob_dist(db_data, db_metadata) - cdf = { - "cdf": cdf_dataset.cdf.values.tolist(), - "cdf_bins": cdf_dataset.cdf_bins.values.tolist(), - } - - # Store parameters in a dictionary - field_params = {"power_spectrum": power_spectrum, "cdf": cdf} - return field_params - - -def power_spectrum_1D( - field: np.ndarray, scale_break: Optional[float] = None -) -> Tuple[Optional[xr.Dataset], Optional[Dict[str, float]]]: - """ - Calculate the 1D isotropic power spectrum and fit a power law model. - - Args: - field (np.ndarray): 2D input field in [rows, columns] order. - scale_break (float, optional): Scale break in pixel units. If None, fit single line. - - Returns: - ps_dataset (xarray.Dataset): 1D isotropic power spectrum in dB. 
- model_params (dict): Dictionary with model parameters: beta_1, beta_2, c1, c2, scale_break - """ - min_stdev = 0.1 - mean = np.nanmean(field) - stdev = np.nanstd(field) - if stdev < min_stdev: - return None, None - - norm_field = (field - mean) / stdev - np.nan_to_num(norm_field, copy=False) - - field_fft = np.fft.rfft2(norm_field) - power_spectrum = np.abs(field_fft) ** 2 - - freq_x = np.fft.fftfreq(field.shape[1]) - freq_y = np.fft.fftfreq(field.shape[0]) - freq_r = np.sqrt(freq_x[:, None] ** 2 + freq_y[None, :] ** 2) - freq_r = freq_r[: field.shape[0] // 2, : field.shape[1] // 2] - power_spectrum = power_spectrum[: field.shape[0] // 2, : field.shape[1] // 2] - - n_bins = power_spectrum.shape[0] - bins = np.logspace( - np.log10(freq_r.min() + 1 / n_bins), np.log10(freq_r.max()), num=n_bins - ) - bin_centers = (bins[:-1] + bins[1:]) / 2 - power_1d = np.zeros(len(bin_centers)) - - for i in range(len(bins) - 1): - mask = (freq_r >= bins[i]) & (freq_r < bins[i + 1]) - power_1d[i] = np.nanmean(power_spectrum[mask]) if np.any(mask) else np.nan - - valid = (bin_centers > 0) & (~np.isnan(power_1d)) - bin_centers = bin_centers[valid] - power_1d = power_1d[valid] - - if len(bin_centers) == 0: - return None, None - - log_x = 10 * np.log10(bin_centers) - log_y = 10 * np.log10(power_1d) - - start_idx = 2 - end_idx = np.searchsorted(log_x, -4.0) - - model_params = {} - - if scale_break is None: - - def str_line(X, m, c): - return m * X + c - - popt, _ = curve_fit( - str_line, log_x[start_idx:end_idx], log_y[start_idx:end_idx] - ) - beta_1, c1 = popt - beta_2 = None - c2 = None - sb_log = None - else: - sb_freq = 1.0 / scale_break - sb_log = 10 * np.log10(sb_freq) - - def piecewise_linear(x, m1, m2, c1): - c2 = (m1 - m2) * sb_log + c1 - return np.where(x <= sb_log, m1 * x + c1, m2 * x + c2) - - popt, _ = curve_fit( - piecewise_linear, log_x[start_idx:end_idx], log_y[start_idx:end_idx] - ) - beta_1, beta_2, c1 = popt - c2 = (beta_1 - beta_2) * sb_log + c1 - - ps_dataset = xr.Dataset( - {"psd": (["bin"], log_y)}, - coords={"psd_bins": (["bin"], log_x)}, - attrs={"description": "1-D Isotropic power spectrum", "units": "dB"}, - ) - - model_params = { - "beta_1": float(beta_1), - "beta_2": float(beta_2) if beta_2 is not None else None, - "c1": float(c1), - "c2": float(c2) if c2 is not None else None, - "scale_break": float(scale_break) if scale_break is not None else None, - } - - return ps_dataset, model_params - - -def prob_dist(data: np.ndarray, metadata: dict): - """ - Calculate the cumulative probability distribution for rain > threshold for dB field - - Args: - data (np.ndarray): 2D field of dB-transformed rain. - metadata (dict): pysteps metadata dictionary. 
- - Returns: - tuple: - - xarray Dataset containing the cumulative probability distribution and bin edges - - fraction of field with rain > threshold (float) - """ - - rain_mask = data > metadata["zerovalue"] - - # Compute cumulative probability distribution - min_db = metadata["threshold"] - max_db = 10 * np.log10(MAX_RAIN_RATE) - bin_edges = np.linspace(min_db, max_db, N_BINS) - - # Histogram of rain values - hist, _ = np.histogram(data[rain_mask], bins=bin_edges, density=True) - - # Compute cumulative distribution - cumulative_distr = np.cumsum(hist) / np.sum(hist) - - # Create an xarray Dataset to store both cumulative distribution and bin edges - cdf_dataset = xr.Dataset( - { - "cdf": (["bin"], cumulative_distr), - }, - coords={ - # bin_edges[:-1] to match the histogram bins - "cdf_bins": (["bin"], bin_edges[:-1]), - }, - attrs={ - "description": "Cumulative probability distribution of rain rates", - "units": "dB", - }, - ) - - return cdf_dataset - - -def compute_field_stats(data, geodata): - nonzero_mask = data >= geodata["threshold"] - nonzero_mean = np.mean(data[nonzero_mask]) if np.any(nonzero_mask) else np.nan - nonzero_stdev = np.std(data[nonzero_mask]) if np.any(nonzero_mask) else np.nan - nonzero_frac = np.sum(nonzero_mask) / data.size - mean_rain = np.nanmean(data) - stdev_rain = np.nanstd(data) - - rain_stats = { - "nonzero_mean": float(nonzero_mean) if nonzero_mean is not None else None, - "nonzero_stdev": float(nonzero_stdev) if nonzero_stdev is not None else None, - "nonzero_fraction": float(nonzero_frac) if nonzero_frac is not None else None, - "mean": float(mean_rain) if mean_rain is not None else None, - "stdev": float(stdev_rain) if stdev_rain is not None else None, - "transform": geodata["transform"], - "zerovalue": geodata["zerovalue"], - "threshold": geodata["threshold"], - } - return rain_stats - - -def is_stationary(phi1, phi2): - return abs(phi2) < 1 and (phi1 + phi2) < 1 and (phi2 - phi1) < 1 - - -def correlation_length( - lag1: float, lag2: float, dx=10, tol=1e-4, max_lag=1000 -) -> float: - """ - Calculate the correlation length in minutes assuming AR(2) process - Args: - lag1 (float): Lag 1 auto-correltion - lag2 (float): Lag 2 auto-correlation - dx (int, optional): time step between lag1 & 2 in minutes. Defaults to 10. - tol (float, optional): _description_. Defaults to 1e-4. - max_lag (int, optional): _description_. Defaults to 1000. - - Returns: - corl (float): Correlation length in minutes - np.nan on error - """ - if lag1 is None or lag2 is None: - return np.nan - - A = np.array([[1.0, lag1], [lag1, 1.0]]) - b = np.array([lag1, lag2]) - - try: - phi = np.linalg.solve(A, b) - except np.linalg.LinAlgError: - return np.nan - - phi1, phi2 = phi - if not is_stationary(phi1, phi2): - return np.nan - - rho_vals = [1.0, lag1, lag2] - for _ in range(3, max_lag): - next_rho = phi1 * rho_vals[-1] + phi2 * rho_vals[-2] - if abs(next_rho) < tol: - break - rho_vals.append(next_rho) - corl = np.trapz(rho_vals, dx=dx) - return float(corl) - - -def power_law_acor( - config: Dict[str, Any], T_ref: float -) -> tuple[np.ndarray, np.ndarray]: - """ - Compute lag-1 and lag-2 autocorrelations for each cascade level using a power-law model. - - Args: - config (dict): Configuration dictionary with 'pysteps.timestep' (in seconds) - and 'dynamic_scaling' parameters. - T_ref (float): Reference correlation length T(t, L) at the largest scale (in minutes). - - Returns: - np.ndarray: Array of shape (n_levels, 2) with [lag1, lag2] for each level. 
- np.ndarray: Array of corelation lengths per level - """ - dt_seconds = config["pysteps"]["timestep"] - dt_mins = dt_seconds / 60.0 - - ds_config = config.get("dynamic_scaling", {}) - scales = ds_config["central_wave_lengths"] - ht = ds_config["space_time_exponent"] - a = ds_config["lag2_constants"] - b = ds_config["lag2_exponents"] - - L = scales[0] - T_levels = [T_ref * (l / L) ** ht for l in scales] - - lags = np.empty((len(scales), 2), dtype=np.float32) - for ia, T_l in enumerate(T_levels): - pl_lag1 = np.exp(-dt_mins / T_l) - pl_lag2 = a[ia] * (pl_lag1 ** b[ia]) - lags[ia, 0] = pl_lag1 - lags[ia, 1] = pl_lag2 - - levels = np.array(T_levels) - return lags, levels diff --git a/pysteps/param/shared_utils.py b/pysteps/param/shared_utils.py deleted file mode 100644 index f37767045..000000000 --- a/pysteps/param/shared_utils.py +++ /dev/null @@ -1,578 +0,0 @@ -from collections.abc import Callable - -import datetime -import copy -import logging -import numpy as np -import pandas as pd - -from statsmodels.tsa.api import SimpleExpSmoothing - -from pysteps.cascade.decomposition import decomposition_fft, recompose_fft -from pysteps.timeseries.autoregression import ( - adjust_lag2_corrcoef2, - estimate_ar_params_yw, -) -from pysteps import extrapolation - -from pysteps.utils.transformer import DBTransformer -from .steps_params import StepsParameters -from .stochastic_generator import gen_stoch_field, normalize_db_field -from .rainfield_stats import correlation_length -from .rainfield_stats import power_spectrum_1D -from .cascade_utils import lagr_auto_cor, calculate_wavelengths - - -def update_field( - cascades: list, - oflow: np.ndarray, - params: StepsParameters, - bp_filter: dict, - kmperpixel: float, - scale_break_km: float, - db_threshold: float, - zerovalue: float, -) -> np.ndarray: - """ - Generate a field conditioned on the cascades that are passed into the function - - Args: - cascades (list): cascades in reverse time order [t-1,t-2] - oflow (np.ndarray): _description_ - params (StepsParameters): _description_ - bp_filter (dict): _description_ - kmperpixel (float): _description_ - scale_break_km (float): _description_ - db_threshold (float): _description_ - zerovalue (float): _description_ - - Returns: - np.ndarray: _description_ - """ - ar_order = 2 # Assume AR(2) for now (It is the best option) - n_levels = cascades[0]["cascade_levels"].shape[0] - n_rows = cascades[0]["cascade_levels"].shape[1] - n_cols = cascades[0]["cascade_levels"].shape[2] - - # Set up the AR(2) parameters - phi = np.zeros((n_levels, ar_order + 1)) - for ilev in range(n_levels): - gamma_1 = params.lag_1[ilev] - gamma_2 = params.lag_2[ilev] - if ar_order == 2: - gamma_2 = adjust_lag2_corrcoef2(gamma_1, gamma_2) - phi[ilev] = estimate_ar_params_yw([gamma_1, gamma_2]) - else: - phi[ilev] = estimate_ar_params_yw([gamma_1]) - - # Generate the noise field and cascade - noise_field = gen_stoch_field( - params, n_cols, n_rows, kmperpixel, scale_break_km, db_threshold - ) - noise_cascade = decomposition_fft( - noise_field, bp_filter, compute_stats=True, normalize=True - ) - - # Update the cascade - extrapolation_method = extrapolation.get_method("semilagrangian") - lag_0 = np.zeros((n_levels, n_rows, n_cols)) - - if ar_order == 2: - lag_1 = copy.deepcopy(cascades[0]["cascade_levels"]) - lag_2 = copy.deepcopy(cascades[1]["cascade_levels"]) - - for ilev in range(n_levels): - adv_lag2 = extrapolation_method(lag_2[ilev], oflow, 2, outval=0)[1] - adv_lag1 = extrapolation_method(lag_1[ilev], oflow, 1, outval=0)[0] - lag_0[ilev] 
= ( - phi[ilev, 0] * adv_lag1 - + phi[ilev, 1] * adv_lag2 - + phi[ilev, 2] * noise_cascade["cascade_levels"][ilev] - ) - - else: - lag_1 = copy.deepcopy(cascades[0]["cascade_levels"]) - for ilev in range(n_levels): - adv_lag1 = extrapolation_method(lag_1[ilev], oflow, 1, outval=0)[0] - lag_0[ilev] = ( - phi[ilev, 0] * adv_lag1 - + phi[ilev, 1] * noise_cascade["cascade_levels"][ilev] - ) - - # Make sure we have mean = 0, stdev = 1 - lev_mean = np.mean(lag_0) - lev_stdev = np.std(lag_0) - if lev_stdev > 1e-1: - lag_0 = (lag_0 - lev_mean) / lev_stdev - - # Recompose the cascade into a single field - updated_cascade = {} - updated_cascade["domain"] = "spatial" - updated_cascade["normalized"] = True - updated_cascade["compact_output"] = False - updated_cascade["cascade_levels"] = lag_0.copy() - - # Use the noise cascade level stds - updated_cascade["means"] = noise_cascade["means"].copy() - updated_cascade["stds"] = noise_cascade["stds"].copy() - gen_field = recompose_fft(updated_cascade) - - # Normalise the field to have the expected conditional mean and variance - norm_field = normalize_db_field(gen_field, params, db_threshold, zerovalue) - - return norm_field - - -def zero_state(n_rows: int, n_cols: int, n_levels: int): - """_summary_ - Generate a dictionary with the data that is needed for the steps nowcast - Args: - n_rows (int): _description_ - n_cols (int): _description_ - n_levels (int): _description_ - - Returns: - _type_: _description_ - """ - metadata_dict = { - "transform": None, - "threshold": None, - "zerovalue": None, - "mean": float(0), - "std_dev": float(0), - "wetted_area_ratio": float(0), - } - cascade_dict = { - "cascade_levels": np.zeros((n_levels, n_rows, n_cols)), - "means": np.zeros(n_levels), - "stds": np.zeros(n_levels), - "domain": "spatial", - "normalized": True, - } - oflow = np.zeros((2, n_rows, n_cols)) - state = {"cascade": cascade_dict, "optical_flow": oflow, "metadata": metadata_dict} - return state - - -def is_zero_state(state, tol=1e-6) -> bool: - """ - Determine if the state is empty - Args: - state (_type_): _description_ - tol (_type_, optional): _description_. Defaults to 1e-6. - - Returns: - _type_: True or False - """ - return abs(state["metadata"]["mean"]) < tol - - -# climatology of the parameters for radar QPE 9000 sets of parameters in Auckland -# 95% 50% 5% -# nonzero_mean_db 6.883147 4.590082 2.815397 -# nonzero_stdev_db 3.793680 2.489131 1.298552 -# rain_fraction 0.447717 0.048889 0.008789 -# beta_1 -0.452957 -1.681647 -2.726216 -# beta_2 -2.322891 -3.251342 -4.009131 -# corl_zero 1074.976508 188.058276 23.489147 - - -def qc_params(ens_df: pd.DataFrame, config: dict, domain: dict) -> pd.DataFrame: - """ - Apply QC to the 'param' column in the ensemble DataFrame. - The DataFrame is assumed to have 'valid_time' as the index. - - Smooth corl_zero using exponential smoothing and recompute cascade autocorrelations. - Clamp smoothed parameters to climatological bounds. - Returns a deep-copied DataFrame with corrected parameters. 
- """ - var_list = [ - "nonzero_mean_db", - "nonzero_stdev_db", - "rain_fraction", - "beta_1", - "beta_2", - "corl_zero", - ] - var_lower = [2.81, 1.30, 0.0, -2.73, -4.01, 60] - var_upper = [9.50, 5.00, 1.0, -2.05, -2.32, 1000] - - qc_df = ens_df.copy(deep=True) - qc_dict = {} - - # Smooth each variable and clamp to bounds - for iv, var in enumerate(var_list): - x_list = [ - ( - np.nan - if qc_df.at[idx, "param"].get(var) is None - else qc_df.at[idx, "param"].get(var) - ) - for idx in qc_df.index - ] - - model = SimpleExpSmoothing(x_list, initialization_method="estimated").fit( - smoothing_level=0.10, optimized=False - ) - qc_dict[var] = np.clip(model.fittedvalues, var_lower[iv], var_upper[iv]) - - # Assign smoothed parameters and compute lags - for i, idx in enumerate(qc_df.index): - param = copy.deepcopy(qc_df.at[idx, "param"]) - - # Ensure valid spectral slope order - if qc_dict["beta_2"][i] > qc_dict["beta_1"][i]: - qc_dict["beta_2"][i] = qc_dict["beta_1"][i] - - # Assign smoothed & clamped values - for var in var_list: - setattr(param, var, qc_dict[var][i]) - param.corl_zero = qc_dict["corl_zero"][i] - - # Compute lag-1 and lag-2 for each level in the cascade - lags, _ = calc_auto_cors(config, domain, param.corl_zero) - param.lag_1 = list(lags[:, 0]) - param.lag_2 = list(lags[:, 1]) - - # Save updated object - qc_df.at[idx, "param"] = param - - return qc_df - - -def blend_param(qpe_params, nwp_params, param_names, weight): - for pname in param_names: - - qval = getattr(qpe_params, pname, None) - nval = getattr(nwp_params, pname, None) - if isinstance(qval, (int, float)) and isinstance(nval, (int, float)): - setattr(nwp_params, pname, weight * qval + (1 - weight) * nval) - elif ( - isinstance(qval, list) and isinstance(nval, list) and len(qval) == len(nval) - ): - setattr( - nwp_params, - pname, - [weight * q + (1 - weight) * n for q, n in zip(qval, nval)], - ) - return nwp_params - - -def blend_parameters( - config: dict, - domain: dict, - blend_base_time: datetime.datetime, - nwp_param_df: pd.DataFrame, - rad_param: StepsParameters, - weight_fn: Callable[[float], float] | None = None, -) -> pd.DataFrame: - """ - Function to blend the radar and NWP parameters - - Args: - config (dict): Configuration dictionary - blend_base_time (datetime.datetime): Time of the radar parameter set - nwp_param_df (pd.DataFrame): Dataframe of valid_time and parameters, - with valid_time as index and of type datetime.datetime - rad_param (StochasticRainParameters): Parameter object with radar parameters - weight_fn (Optional[Callable[[float], float]], optional): _description_. Defaults to None. 
- - Returns: - pd.DataFrame: _description_ - """ - - def default_weight_fn(lag_sec: float) -> float: - return np.exp(-((lag_sec / 10800) ** 2)) # 3h Gaussian - - if weight_fn is None: - weight_fn = default_weight_fn - - blended_param_names = [ - "nonzero_mean_db", - "nonzero_stdev_db", - "rain_fraction", - "beta_1", - "beta_2", - "corl_zero", - ] - blended_df = copy.deepcopy(nwp_param_df) - for vtime in blended_df.index: - lag_sec = (vtime - blend_base_time).total_seconds() - weight = weight_fn(lag_sec) - - # Select the parameter object for this vtime and blend - original = blended_df.loc[vtime, "param"] - clean_original = copy.deepcopy(original) - updated = blend_param(rad_param, clean_original, blended_param_names, weight) - - # Compute lag-1 and lag-2 for this correlation length - lags = calc_auto_cors(config, domain, updated.corl_zero) - updated.lag_1 = list(lags[:, 0]) - updated.lag_2 = list(lags[:, 1]) - - blended_df.loc[vtime, "param"] = updated - - return blended_df - - -def fill_param_gaps( - ens_df: pd.DataFrame, forecast_times: list[datetime.datetime] -) -> pd.DataFrame: - """ - Fill gaps in the time series of parameters with the most recent *original* observation - if the gap is smaller than a threshold. - - Assumes that all the parameters have the same domain, product, base_time, ensemble. - - Args: - ens_df (pd.DataFrame): DataFrame with columns 'valid_time' and 'param'. - forecast_times (list): List of datetime.datetime in UTC. - - Returns: - pd.DataFrame: DataFrame with gaps filled. - """ - max_gap = datetime.timedelta(hours=6) - - ens_df = ens_df.copy() - ens_df["valid_time"] = pd.to_datetime(ens_df["valid_time"], utc=True) - ens_df = ens_df.sort_values("valid_time").reset_index(drop=True) - - filled_map = dict(zip(ens_df["valid_time"], ens_df["param"])) - original_times = set(ens_df["valid_time"]) - - # Extract default metadata - first_param = ens_df.iloc[0].at["param"] - def_metadata_base = first_param.metadata.copy() - - for vtime in forecast_times: - if vtime in filled_map: - continue - - metadata = def_metadata_base.copy() - metadata["valid_time"] = vtime - - # Find the nearest valid time - if original_times: - nearest_time = min(original_times, key=lambda t: abs(t - vtime)) - gap = abs(nearest_time - vtime) - - if gap <= max_gap: - logging.debug( - f"Filling {vtime} with params from nearest time {nearest_time} (gap = {gap})" - ) - def_param = copy.deepcopy(filled_map[nearest_time]) - def_param.metadata = metadata - def_param.rain_fraction = 0 - else: - logging.debug( - f"Nearest gap too large to fill for {vtime}, using default" - ) - def_param = StepsParameters(metadata=metadata) - else: - logging.debug(f"No valid parameter found near {vtime}, using default") - def_param = StepsParameters(metadata=metadata) - - filled_map[vtime] = def_param - - records = [{"valid_time": t, "param": p} for t, p in sorted(filled_map.items())] - return pd.DataFrame(records) - - -def fit_auto_cors( - clen: float, - alpha: float, - d_mins: int, - *, - allow_negative: bool = False, - return_diagnostics: bool = False, -): - """ - Find lag1, lag2 (with lag2 = lag1**alpha) such that - correlation_length(lag1, lag2, d_mins) ~= clen. - - Parameters - ---------- - clen : float - Target correlation length (minutes), must be > 0. - alpha : float - Exponent linking lag2 and lag1 via lag2 = lag1**alpha. - d_mins : int - Time step between lag1 and lag2 (minutes). - allow_negative : bool, optional - If True, search lag1 in (-1, 1). Otherwise restrict to (0, 1). 
- return_diagnostics : bool, optional - If True, also return achieved correlation length and absolute error. - - Returns - ------- - lag1 : float - lag2 : float - (achieved_clen, abs_error) : tuple[float, float], only if return_diagnostics=True - """ - if not np.isfinite(clen) or clen <= 0: - raise ValueError("clen must be a positive, finite number.") - if not np.isfinite(alpha): - raise ValueError("alpha must be finite.") - if not np.isfinite(d_mins) or d_mins <= 0: - raise ValueError("d_mins must be a positive, finite number.") - - # Stability / search bounds for lag1 - eps = 1e-3 - lo, hi = (-0.999999, 0.999999) if allow_negative else (eps, 0.999999) - - # Objective: squared error on correlation length - def _obj(l1: float) -> float: - # Quick rejection of out-of-bounds - if not (lo < l1 < hi): - return np.inf - l2 = l1**alpha - - # Keep |lag2| < 1 as well to stay in a stable region - if not (abs(l2) < 1.0): - return np.inf - - # Make sure that we have a valid lag1, lag2 combination - l2 = adjust_lag2_corrcoef2(l1, l2) - - c = correlation_length(l1, l2, d_mins) - if not np.isfinite(c): - return np.inf - return (c - clen) ** 2 - - # Try SciPy first - lag1 = None - try: - from scipy.optimize import minimize_scalar # type: ignore - - res = minimize_scalar( - _obj, bounds=(lo, hi), method="bounded", options={"xatol": 1e-10} - ) - lag1 = float(res.x) - except Exception: - # Pure NumPy golden-section fallback - phi = (1.0 + np.sqrt(5.0)) / 2.0 - a, b = lo, hi - c = b - (b - a) / phi - d = a + (b - a) / phi - fc = _obj(c) - fd = _obj(d) - # Max ~100 iterations gives ~1e-8 bracket typically - for _ in range(100): - if fc < fd: - b, d, fd = d, c, fc - c = b - (b - a) / phi - fc = _obj(c) - else: - a, c, fc = c, d, fd - d = a + (b - a) / phi - fd = _obj(d) - if (b - a) < 1e-10: - break - lag1 = float((a + b) / 2.0) - - lag2 = float(lag1**alpha) - achieved = float(correlation_length(lag1, lag2, d_mins)) - err = abs(achieved - clen) - - if return_diagnostics: - return lag1, lag2, (achieved, err) - return lag1, lag2 - - -def calc_cor_lengths(scales: np.ndarray, czero: float, ht: float) -> list[float]: - # Power law function for correlation length - corls = [czero] - lzero = scales[0] - for scale in scales[1:]: - corl = czero * (scale / lzero) ** ht - corls.append(corl) - return corls - - -def calc_auto_cors(config: dict, domain: dict, czero: float) -> np.ndarray: - timestep = config["timestep"] - dt = timestep / 60 - n_levels = config["n_cascade_levels"] - ht = config.get("ht", 0.96) - alpha = config.get("alpha", 2.83) - n_cols = domain["n_cols"] - n_rows = domain["n_rows"] - p_size = domain["p_size"] - domain_size = max(n_rows, n_cols) - d = 1000 / p_size - scales = calculate_wavelengths(n_levels, domain_size, d) - cor_lens = calc_cor_lengths(scales, czero, ht) - lags = np.zeros((n_levels, 2)) - for ilev in range(n_levels): - r1, r2 = fit_auto_cors(cor_lens[ilev], alpha, dt) # type: ignore - lags[ilev, 0] = r1 - lags[ilev, 1] = r2 - return lags - - -def calculate_parameters( - db_field: np.ndarray, - cascades: list[dict], - oflow: np.ndarray, - scale_break: float, - zero_value: float, - dt: int, -) -> StepsParameters: - """ - Calculate the steps parameters - - Args: - db_field (np.ndarray): - cascades (list[dict]): list of cascade dictionaries in order t-2, t-1, t0 - oflow (np.ndarray): Optical flow from t-1 to t0 - scale_break (float): Location of the power spectum scale break in pixels - zero_value (float): zero value for the dBR transformation - dt (int): time step in minutes - - Returns: - 
StepsParameters: Dataclass with the steps parameters - """ - p_dict = {} - - # Probability distribution moments - nonzero_mask = db_field > zero_value - p_dict["nonzero_mean_db"] = ( - np.mean(db_field[nonzero_mask]) if np.any(nonzero_mask) else np.nan - ) - p_dict["nonzero_stdev_db"] = ( - np.std(db_field[nonzero_mask]) if np.any(nonzero_mask) else np.nan - ) - p_dict["rain_fraction"] = np.sum(nonzero_mask) / db_field.size - - # Power spectrum slopes - _, ps_model = power_spectrum_1D(db_field, scale_break) - if ps_model: - p_dict["beta_1"] = ps_model.get("beta_1", -2.05) - p_dict["beta_2"] = ps_model.get("beta_2", -3.2) - else: - p_dict["beta_1"] = -2.05 - p_dict["beta_2"] = -3.2 - - # Stack the (k,m,n) arrays in order t-2, t-1, t0 to get (t,k,m,n) array - data = [] - for ia in range(3): - data.append(cascades[ia]["cascade_levels"]) - data = np.stack(data) - a_corls = lagr_auto_cor(data, oflow) - n_levels = a_corls.shape[0] - lag_1 = [] - lag_2 = [] - clens = [] - for ilag in range(n_levels): - r1 = float(a_corls[ilag][0]) - r2 = float(a_corls[ilag][1]) - clen = correlation_length(r1, r2, dt) - lag_1.append(r1) - lag_2.append(r2) - clens.append(clen) - - p_dict["lag_1"] = lag_1 - p_dict["lag_2"] = lag_2 - p_dict["corl_zero"] = clens[0] - - return StepsParameters.from_dict(p_dict) diff --git a/pysteps/param/steps_params.py b/pysteps/param/steps_params.py deleted file mode 100644 index fcdfcde70..000000000 --- a/pysteps/param/steps_params.py +++ /dev/null @@ -1,125 +0,0 @@ -from dataclasses import dataclass, field -import datetime - -# These are representative values from a 250 km in Auckland N.Z -# 95% 50% 5% -# nonzero_mean_db 6.883147 4.590082 2.815397 -# nonzero_stdev_db 3.793680 2.489131 1.298552 -# rain_fraction 0.447717 0.048889 0.008789 -# beta_1 -0.452957 -1.681647 -2.726216 -# beta_2 -2.322891 -3.251342 -4.009131 -# corl_zero 1074.976508 188.058276 23.489147 - - -@dataclass -class StepsParameters: - metadata: dict - - # STEPS parameters with defaults for light rain - nonzero_mean_db: float = 2.81 - nonzero_stdev_db: float = 1.3 - rain_fraction: float = 0 - beta_1: float = -2.05 - beta_2: float = -3.2 - corl_zero: float = 180 - - # Auto-correlation lists - lag_1: list[float] = field(default_factory=list) - lag_2: list[float] = field(default_factory=list) - - # Required metadata keys - _required_metadata_keys = { - "domain", - "product", - "valid_time", - "base_time", - "ensemble", - } - - def get(self, key: str, default=None): - """Mimic dict.get(). 
Check metadata first, then top-level attributes.""" - if key in self.metadata: - value = self.metadata.get(key) - if ( - value is None - and key in self._required_metadata_keys - and default is None - ): - raise KeyError(f"Required metadata key '{key}' is missing or None.") - return value if value is not None else default - else: - return getattr(self, key, default) - - def set_metadata(self, key: str, value): - """Set a metadata key/value pair and validate if required.""" - self.metadata[key] = value - if key in self._required_metadata_keys and value is None: - raise ValueError(f"Required metadata key '{key}' cannot be None.") - - def validate(self): - """Raise ValueError if any required field is missing or None.""" - for key in self._required_metadata_keys: - if key not in self.metadata or self.metadata[key] is None: - raise ValueError(f"Missing required metadata field: '{key}'") - - @classmethod - def from_dict(cls, data: dict): - """Create a StepsParameters object from a dictionary.""" - - def ensure_utc(dt): - if dt is None: - return None - if isinstance(dt, str): - dt = datetime.datetime.fromisoformat(dt) - if dt.tzinfo is None: - return dt.replace(tzinfo=datetime.timezone.utc) - return dt.astimezone(datetime.timezone.utc) - - meta = data.get("metadata", {}) - if meta is not None: - metadata = { - "domain": meta.get("domain"), - "product": meta.get("product"), - "valid_time": ensure_utc(meta.get("valid_time")), - "base_time": ensure_utc(meta.get("base_time")), - "ensemble": meta.get("ensemble"), - } - else: - metadata = {} - - return cls( - metadata=metadata, - nonzero_mean_db=data.get("nonzero_mean_db", 2.81), - nonzero_stdev_db=data.get("nonzero_stdev_db", 1.3), - rain_fraction=data.get("rain_fraction", 0), - beta_1=data.get("beta_1", -2.05), - beta_2=data.get("beta_2", -3.2), - corl_zero=data.get("corl_zero", 180), - lag_1=data.get("lag_1", []), - lag_2=data.get("lag_2", []), - ) - - def to_dict(self): - """Convert the object into a dictionary for MongoDB or JSON.""" - if self.metadata is not None: - metadata = { - "domain": self.metadata["domain"], - "product": self.metadata["product"], - "valid_time": self.metadata["valid_time"], - "base_time": self.metadata["base_time"], - "ensemble": self.metadata["ensemble"], - } - else: - metadata = {} - - return { - "metadata": metadata, - "nonzero_mean_db": self.nonzero_mean_db, - "nonzero_stdev_db": self.nonzero_stdev_db, - "rain_fraction": self.rain_fraction, - "beta_1": self.beta_1, - "beta_2": self.beta_2, - "corl_zero": self.corl_zero, - "lag_1": self.lag_1, - "lag_2": self.lag_2, - } diff --git a/pysteps/param/stochastic_generator.py b/pysteps/param/stochastic_generator.py deleted file mode 100644 index 22cc68117..000000000 --- a/pysteps/param/stochastic_generator.py +++ /dev/null @@ -1,225 +0,0 @@ -# Contains: gen_stoch_field, normalize_db_field, pl_filter -from typing import Optional -import numpy as np -from scipy import interpolate, stats -from .steps_params import StepsParameters - - -def gen_stoch_field( - steps_params: StepsParameters, - nx: int, - ny: int, - pixel_size: float, - scale_break: float, - threshold: float, -): - """ - Generate a rain field with normal distribution and a power law power spectrum - Args: - steps_params (StepsParameters): The dataclass with the steps parameters - nx (int): x dimension of the output field - ny (int): y dimension of the output field - kmperpixel (float): pixel size - scale_break (float): scale break in km - threshold (float): rain threshold in db - - Returns: - np.ndarray: Output field 
with shape (ny,nx) - """ - - beta_1 = steps_params.beta_1 - beta_2 = steps_params.beta_2 - - # generate uniform random numbers in the range 0,1 - y = np.random.uniform(low=0, high=1, size=(ny, nx)) - - # Power law filter the field - fft = np.fft.fft2(y, (ny, nx)) - filter = pl_filter(beta_1, nx, ny, pixel_size, beta_2, scale_break) - out_fft = fft * filter - out_field = np.fft.ifft2(out_fft).real - - nbins = 500 - eps = 0.001 - - res = stats.cumfreq(out_field, numbins=nbins) - bins = [res.lowerlimit + ia * res.binsize for ia in range(1 + res.cumcount.size)] - count = res.cumcount / res.cumcount[nbins - 1] - - # find the threshold value for this non-rain probability - rain_bin = 0 - for ia in range(nbins): - if count[ia] <= 1 - steps_params.rain_fraction: - rain_bin = ia - else: - break - rain_threshold = bins[rain_bin] - - # Shift the data to have the correct probability > 0 - norm_data = out_field - rain_threshold - - # Now we need to transform the "raining" samples to have the desired distribution - rain_mask = norm_data > threshold - rain_obs = norm_data[rain_mask] - rain_res = stats.cumfreq(rain_obs, numbins=nbins) - rain_bins = [ - rain_res.lowerlimit + ia * rain_res.binsize - for ia in range(1 + rain_res.cumcount.size) - ] - rain_cdf = rain_res.cumcount / rain_res.cumcount[nbins - 1] - - # rain_bins are the bin edges; use bin centers for interpolation - bin_centers = 0.5 * (np.array(rain_bins[:-1]) + np.array(rain_bins[1:])) - - # Step 1: Build LUT: map empirical CDF → target normal quantiles - # Make sure rain_cdf values are in (0,1) to avoid issues with extreme tails - rain_cdf_clipped = np.clip(rain_cdf, eps, 1 - eps) - - # Map rain_cdf quantiles to corresponding values in the target normal distribution - target_mu = steps_params.nonzero_mean_db - target_sigma = steps_params.nonzero_stdev_db - normal_values = stats.norm.ppf(rain_cdf_clipped, loc=target_mu, scale=target_sigma) - - # Create interpolation function from observed rain values to target normal values - cdf_transform = interpolate.interp1d( - bin_centers, - normal_values, - kind="linear", - bounds_error=False, - fill_value=(normal_values[0], normal_values[-1]), # type: ignore - ) - - # Transform pdf of the raining pixels - norm_data[rain_mask] = cdf_transform(norm_data[rain_mask]) - - return norm_data - - -def normalize_db_field(data, params, threshold, zerovalue): - if params.rain_fraction < 0.025: - return np.full_like(data, zerovalue) - - nbins = 500 - eps = 0.0001 - - res = stats.cumfreq(data, numbins=nbins) - bins = [res.lowerlimit + ia * res.binsize for ia in range(1 + res.cumcount.size)] - count = res.cumcount / res.cumcount[nbins - 1] - - # find the threshold value for this non-rain probability - rain_bin = 0 - for ia in range(nbins): - if count[ia] <= 1 - params.rain_fraction: - rain_bin = ia - else: - break - rain_threshold = bins[rain_bin + 1] - - # Shift the data to have the correct probability of rain - norm_data = data + (threshold - rain_threshold) - - # Now we need to transform the raining samples to have the desired distribution - # Get the sample distribution - rain_mask = norm_data > threshold - rain_obs = norm_data[rain_mask] - rain_res = stats.cumfreq(rain_obs, numbins=nbins) - - rain_bins = [ - rain_res.lowerlimit + ia * rain_res.binsize - for ia in range(1 + rain_res.cumcount.size) - ] - rain_cdf = rain_res.cumcount / rain_res.cumcount[nbins - 1] - - # rain_bins are the bin edges; use bin centers for interpolation - bin_centers = 0.5 * (np.array(rain_bins[:-1]) + np.array(rain_bins[1:])) - - # Step 
1: Build LUT: map empirical CDF → target normal quantiles - # Make sure rain_cdf values are in (0,1) to avoid issues with extreme tails - rain_cdf_clipped = np.clip(rain_cdf, eps, 1 - eps) - - # Map rain_cdf quantiles to corresponding values in the target normal distribution - # We need to reduce the bias in the output fields - target_mu = params.nonzero_mean_db - target_sigma = params.nonzero_stdev_db - normal_values = stats.norm.ppf(rain_cdf_clipped, loc=target_mu, scale=target_sigma) - - # Create interpolation function from observed rain values to target normal values - fill_value = (normal_values[0], normal_values[-1]) - cdf_transform = interpolate.interp1d( - bin_centers, - normal_values, - kind="linear", - bounds_error=False, - fill_value=fill_value, # type: ignore - ) - - # Transform raining pixels - norm_data[rain_mask] = cdf_transform(norm_data[rain_mask]) - - # Check if we have nans and return zerovalue if yes - has_nan = np.isnan(norm_data).any() - if has_nan: - return np.full_like(data, zerovalue) - else: - return norm_data - - -def pl_filter( - beta_1: float, - nx: int, - ny: int, - pixel_size: float, - beta_2: Optional[float] = None, - scale_break: Optional[float] = None, -): - """ - Generate a 2D low-pass power-law filter for FFT filtering. - - Parameters: - beta_1 (float): Power law exponent for frequencies < f1 (low frequencies) - nx (int): Number of columns (width) in the 2D field - ny (int): Number of rows (height) in the 2D field - pixel_size (float): Pixel size in km - beta_2 (float): Power law exponent for frequencies > f1 (high frequencies) Optional - scale_break (float): Break scale in km Optional - - Returns: - np.ndarray: 2D FFT low-pass filter - """ - - # Compute the frequency grid - freq_x = np.fft.fftfreq(nx, d=pixel_size) # Frequency in x-direction - freq_y = np.fft.fftfreq(ny, d=pixel_size) # Frequency in y-direction - - # 2D array with radial frequency - freq_r = np.sqrt(freq_x[:, None] ** 2 + freq_y[None, :] ** 2) - - # Initialize the radial 2D filter - filter_r = np.ones_like(freq_r) # Initialize with ones - f_zero = freq_x[1] - - if beta_2 is not None and scale_break is not None: - b1 = beta_1 / 2.0 - b2 = beta_2 / 2.0 - - f1 = 1 / scale_break # Convert scale break to frequency domain - weight = (f1 / f_zero) ** b1 - - # Apply the power-law function for a **low-pass filter** - # Handle division by zero at freq = 0 - with np.errstate(divide="ignore", invalid="ignore"): - mask_low = freq_r < f1 # Frequencies lower than the break - mask_high = ~mask_low # Frequencies higher than or equal to the break - - filter_r[mask_low] = (freq_r[mask_low] / f_zero) ** b1 - filter_r[mask_high] = weight * (freq_r[mask_high] / f1) ** b2 - - # Ensure DC component (zero frequency) is handled properly - filter_r[freq_r == 0] = 1 # Preserve the mean component - else: - b1 = beta_1 / 2.0 - mask = freq_r > 0 - filter_r[mask] = (freq_r[mask] / f_zero) ** b1 - filter_r[freq_r == 0] = 1 # Preserve the mean component - - return filter_r
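For reference, a minimal usage sketch of the power-law filtering step that gen_stoch_field builds on, assuming the pl_filter function defined above is importable (it is part of the stochastic_generator module removed by this patch). The grid size, pixel size, spectral slopes and scale break below are illustrative assumptions, not values taken from the patch, and a square grid is used because the filter and FFT shapes only line up when nx == ny:

    import numpy as np
    # hypothetical import; refers to the (removed) module shown above
    # from pysteps.param.stochastic_generator import pl_filter

    nx = ny = 256            # square grid (assumption; see note above)
    pixel_size = 1.0         # km per pixel (illustrative)
    beta_1, beta_2 = -2.0, -3.2   # spectral slopes below/above the break (illustrative)
    scale_break = 20.0       # break scale in km (illustrative)

    # white noise in [0, 1), as at the start of gen_stoch_field
    rng = np.random.default_rng(0)
    y = rng.uniform(size=(ny, nx))

    # shape the noise spectrum with the 2D power-law filter
    filt = pl_filter(beta_1, nx, ny, pixel_size, beta_2, scale_break)
    field = np.fft.ifft2(np.fft.fft2(y) * filt).real

    # standardise before any CDF matching of the raining pixels
    field = (field - field.mean()) / field.std()

gen_stoch_field performs this same filtering and then remaps the raining pixels to the target conditional distribution via the empirical-CDF lookup shown in the code above.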