diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b1826e0..18dc91ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.4.0] - Unreleased ### Added +- Functions to concatenate input damages across batches. ([PR #83](https://github.com/ClimateImpactLab/dscim/pull/83), [@davidrzhdu](https://github.com/davidrzhdu)) - New unit tests for [dscim/utils/input_damages.py](https://github.com/ClimateImpactLab/dscim/blob/main/src/dscim/preprocessing/input_damages.py). ([PR #68](https://github.com/ClimateImpactLab/dscim/pull/68), [@davidrzhdu](https://github.com/davidrzhdu)) - New unit tests for [dscim/utils/rff.py](https://github.com/ClimateImpactLab/dscim/blob/main/src/dscim/utils/rff.py). ([PR #73](https://github.com/ClimateImpactLab/dscim/pull/73), [@JMGilbert](https://github.com/JMGilbert)) - New unit tests for [dscim/dscim/preprocessing.py](https://github.com/ClimateImpactLab/dscim/blob/main/src/dscim/preprocessing/preprocessing.py). ([PR #67](https://github.com/ClimateImpactLab/dscim/pull/67), [@JMGilbert](https://github.com/JMGilbert)) @@ -23,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Remove old/unnecessary files. ([PR #57](https://github.com/ClimateImpactLab/dscim/pull/57), [@JMGilbert](https://github.com/JMGilbert)) - Remove unused “save_path” and “ec_cls” from `read_energy_files_parallel()`. ([PR #56](https://github.com/ClimateImpactLab/dscim/pull/56), [@davidrzhdu](https://github.com/davidrzhdu)) ### Fixed +- Make all input damages output files with correct chunksizes. ([PR #83](https://github.com/ClimateImpactLab/dscim/pull/83), [@JMGilbert](https://github.com/JMGilbert)) - Add `.load()` to every loading of population data from EconVars. ([PR #82](https://github.com/ClimateImpactLab/dscim/pull/82), [@davidrzhdu](https://github.com/davidrzhdu)) - Make `compute_ag_damages` function correctly save outputs in float32. ([PR #72](https://github.com/ClimateImpactLab/dscim/pull/72) and [PR #82](https://github.com/ClimateImpactLab/dscim/pull/82), [@davidrzhdu](https://github.com/davidrzhdu)) - Make rff damage functions read in and save out in the proper filepath structure. ([PR #79](https://github.com/ClimateImpactLab/dscim/pull/79), [@JMGilbert](https://github.com/JMGilbert)) diff --git a/src/dscim/preprocessing/input_damages.py b/src/dscim/preprocessing/input_damages.py index 005513e5..ab368ff0 100644 --- a/src/dscim/preprocessing/input_damages.py +++ b/src/dscim/preprocessing/input_damages.py @@ -3,7 +3,6 @@ """ import os -import glob import re import logging import warnings @@ -95,6 +94,50 @@ def _parse_projection_filesys(input_path, query="exists==True"): return df.query(query) +def concatenate_damage_output(damage_dir, basename, save_path): + """Concatenate labor/energy damage output across batches. + + Parameters + ---------- + damage_dir str + Directory containing separate labor/energy damage output files by batches. + basename str + Prefix of the damage output filenames (ex. {basename}_batch0.zarr) + save_path str + Path to save concatenated file in .zarr format + """ + paths = [ + f"{damage_dir}/{basename}_{b}.zarr" + for b in ["batch" + str(i) for i in range(0, 15)] + ] + data = xr.open_mfdataset(paths=paths, engine="zarr") + + for v in data: + del data[v].encoding["chunks"] + + chunkies = { + "batch": 15, + "rcp": 1, + "gcm": 1, + "model": 1, + "ssp": 1, + "region": -1, + "year": 10, + } + + data = data.chunk(chunkies) + + for v in list(data.coords.keys()): + if data.coords[v].dtype == object: + data.coords[v] = data.coords[v].astype("unicode") + data.coords["batch"] = data.coords["batch"].astype("unicode") + for v in list(data.variables.keys()): + if data[v].dtype == object: + data[v] = data[v].astype("unicode") + + data.to_zarr(save_path, mode="w") + + def calculate_labor_impacts(input_path, file_prefix, variable, val_type): """Calculate impacts for labor results. @@ -371,7 +414,7 @@ def process_batch(g): batches = [ds for ds in batches if ds is not None] chunkies = { "rcp": 1, - "region": 24378, + "region": -1, "gcm": 1, "year": 10, "model": 1, @@ -738,12 +781,21 @@ def prep( ).expand_dims({"gcm": [gcm]}) damages = damages.chunk( - {"batch": 15, "ssp": 1, "model": 1, "rcp": 1, "gcm": 1, "year": 10} + { + "batch": 15, + "ssp": 1, + "model": 1, + "rcp": 1, + "gcm": 1, + "year": 10, + "region": -1, + } ) damages.coords.update({"batch": [f"batch{i}" for i in damages.batch.values]}) # convert to EPA VSL damages = damages * 0.90681089 + damages = damages.astype(np.float32) for v in list(damages.coords.keys()): if damages.coords[v].dtype == object: @@ -790,6 +842,15 @@ def coastal_inputs( ) else: d = d.sel(adapt_type=adapt_type, vsl_valuation=vsl_valuation, drop=True) + chunkies = { + "batch": 15, + "ssp": 1, + "model": 1, + "slr": 1, + "year": 10, + "region": -1, + } + d = d.chunk(chunkies) d.to_zarr( f"{path}/coastal_damages_{version}-{adapt_type}-{vsl_valuation}.zarr", consolidated=True, diff --git a/src/dscim/preprocessing/preprocessing.py b/src/dscim/preprocessing/preprocessing.py index 67fb2eec..a253b389 100644 --- a/src/dscim/preprocessing/preprocessing.py +++ b/src/dscim/preprocessing/preprocessing.py @@ -102,6 +102,24 @@ def reduce_damages( xr.open_zarr(damages).chunks["batch"][0] == 15 ), "'batch' dim on damages does not have chunksize of 15. Please rechunk." + if "coastal" not in sector: + chunkies = { + "rcp": 1, + "region": -1, + "gcm": 1, + "year": 10, + "model": 1, + "ssp": 1, + } + else: + chunkies = { + "region": -1, + "slr": 1, + "year": 10, + "model": 1, + "ssp": 1, + } + ce_batch_dims = [i for i in gdppc.dims] + [ i for i in ds.dims if i not in gdppc.dims and i != "batch" ] @@ -110,15 +128,14 @@ def reduce_damages( i for i in gdppc.region.values if i in ce_batch_coords["region"] ] ce_shapes = [len(ce_batch_coords[c]) for c in ce_batch_dims] - ce_chunks = [xr.open_zarr(damages).chunks[c][0] for c in ce_batch_dims] template = xr.DataArray( - da.empty(ce_shapes, chunks=ce_chunks), + da.empty(ce_shapes), dims=ce_batch_dims, coords=ce_batch_coords, - ) + ).chunk(chunkies) - other = xr.open_zarr(damages) + other = xr.open_zarr(damages).chunk(chunkies) out = other.map_blocks( ce_from_chunk, @@ -205,7 +222,21 @@ def sum_AMEL( for sector in sectors: print(f"Opening {sector},{params[sector]['sector_path']}") ds = xr.open_zarr(params[sector]["sector_path"], consolidated=True) - ds = ds[params[sector][var]].rename(var) + ds = ( + ds[params[sector][var]] + .rename(var) + .chunk( + { + "batch": 15, + "ssp": 1, + "model": 1, + "rcp": 1, + "gcm": 1, + "year": 10, + "region": -1, + } + ) + ) ds = xr.where(np.isinf(ds), np.nan, ds) datasets.append(ds) diff --git a/tests/test_input_damages.py b/tests/test_input_damages.py index 76b28a1b..d6d8e0b6 100644 --- a/tests/test_input_damages.py +++ b/tests/test_input_damages.py @@ -9,6 +9,7 @@ from dscim.menu.simple_storage import EconVars from dscim.preprocessing.input_damages import ( _parse_projection_filesys, + concatenate_damage_output, calculate_labor_impacts, concatenate_labor_damages, calculate_labor_batch_damages, @@ -31,7 +32,7 @@ def test_parse_projection_filesys(tmp_path): """ Test that parse_projection_filesys correctly retrieves projection system output structure """ - rcp = ["rcp85", "rcp45"] + rcp = ["rcp45", "rcp85"] gcm = ["ACCESS1-0", "GFDL-CM3"] model = ["high", "low"] ssp = [f"SSP{n}" for n in range(2, 4)] @@ -45,14 +46,14 @@ def test_parse_projection_filesys(tmp_path): os.makedirs(os.path.join(tmp_path, b, r, g, m, s)) out_expected = { - "batch": list(chain(repeat("batch9", 16), repeat("batch6", 16))), - "rcp": list(chain(repeat("rcp85", 8), repeat("rcp45", 8))) * 2, + "batch": list(chain(repeat("batch6", 16), repeat("batch9", 16))), + "rcp": list(chain(repeat("rcp45", 8), repeat("rcp85", 8))) * 2, "gcm": list(chain(repeat("ACCESS1-0", 4), repeat("GFDL-CM3", 4))) * 4, "model": list(chain(repeat("high", 2), repeat("low", 2))) * 8, "ssp": ["SSP2", "SSP3"] * 16, "path": [ os.path.join(tmp_path, b, r, g, m, s) - for b in ["batch9", "batch6"] + for b in ["batch6", "batch9"] for r in rcp for g in gcm for m in model @@ -64,11 +65,83 @@ def test_parse_projection_filesys(tmp_path): df_out_expected = pd.DataFrame(out_expected) df_out_actual = _parse_projection_filesys(input_path=tmp_path) + df_out_actual = df_out_actual.sort_values( + by=["batch", "rcp", "gcm", "model", "ssp"] + ) df_out_actual.reset_index(drop=True, inplace=True) pd.testing.assert_frame_equal(df_out_expected, df_out_actual) +def test_concatenate_damage_output(tmp_path): + """ + Test that concatenate_damage_output correctly concatenates damages across batches and saves to a single zarr file + """ + d = os.path.join(tmp_path, "concatenate_in") + if not os.path.exists(d): + os.makedirs(d) + + for b in ["batch" + str(i) for i in range(0, 15)]: + ds_in = xr.Dataset( + { + "delta_rebased": ( + ["ssp", "rcp", "model", "gcm", "batch", "year", "region"], + np.full((2, 2, 2, 2, 1, 2, 2), 1).astype(object), + ), + "histclim_rebased": ( + ["ssp", "rcp", "model", "gcm", "batch", "year", "region"], + np.full((2, 2, 2, 2, 1, 2, 2), 2), + ), + }, + coords={ + "batch": (["batch"], [b]), + "gcm": (["gcm"], np.array(["ACCESS1-0", "BNU-ESM"], dtype=object)), + "model": (["model"], ["IIASA GDP", "OECD Env-Growth"]), + "rcp": (["rcp"], ["rcp45", "rcp85"]), + "region": (["region"], ["ZWE.test_region", "USA.test_region"]), + "ssp": (["ssp"], ["SSP2", "SSP3"]), + "year": (["year"], [2020, 2099]), + }, + ) + + infile = os.path.join(d, f"test_insuffix_{b}.zarr") + + ds_in.to_zarr(infile) + + ds_out_expected = xr.Dataset( + { + "delta_rebased": ( + ["ssp", "rcp", "model", "gcm", "batch", "year", "region"], + np.full((2, 2, 2, 2, 15, 2, 2), 1), + ), + "histclim_rebased": ( + ["ssp", "rcp", "model", "gcm", "batch", "year", "region"], + np.full((2, 2, 2, 2, 15, 2, 2), 2), + ), + }, + coords={ + "batch": (["batch"], ["batch" + str(i) for i in range(0, 15)]), + "gcm": (["gcm"], ["ACCESS1-0", "BNU-ESM"]), + "model": (["model"], ["IIASA GDP", "OECD Env-Growth"]), + "rcp": (["rcp"], ["rcp45", "rcp85"]), + "region": (["region"], ["ZWE.test_region", "USA.test_region"]), + "ssp": (["ssp"], ["SSP2", "SSP3"]), + "year": (["year"], [2020, 2099]), + }, + ) + + concatenate_damage_output( + damage_dir=d, + basename="test_insuffix", + save_path=os.path.join(d, "concatenate.zarr"), + ) + ds_out_actual = xr.open_zarr(os.path.join(d, "concatenate.zarr")).sel( + batch=["batch" + str(i) for i in range(0, 15)] + ) + + xr.testing.assert_equal(ds_out_expected, ds_out_actual) + + @pytest.fixture def labor_in_val_fixture(tmp_path): """ @@ -697,7 +770,9 @@ def energy_in_netcdf_fixture(tmp_path): "region", "year", ], - np.full((1, 1, 1, 1, 1, 2, 2), 2), + np.full((1, 1, 1, 1, 1, 2, 2), 2).astype( + object + ), ), }, coords={ @@ -1030,11 +1105,11 @@ def test_prep_mortality_damages( { "delta": ( ["gcm", "batch", "ssp", "rcp", "model", "year", "region"], - np.full((2, 2, 2, 2, 2, 2, 2), -0.90681089), + np.float32(np.full((2, 2, 2, 2, 2, 2, 2), -0.90681089)), ), "histclim": ( ["gcm", "batch", "ssp", "rcp", "model", "year", "region"], - np.full((2, 2, 2, 2, 2, 2, 2), 2 * 0.90681089), + np.float32(np.full((2, 2, 2, 2, 2, 2, 2), 2 * 0.90681089)), ), }, coords={