177 changes: 176 additions & 1 deletion src/climatebenchpress/compressor/scripts/compress.py
@@ -2,6 +2,7 @@

import argparse
import json
import math
import traceback
from collections.abc import Container, Mapping
from pathlib import Path
@@ -20,6 +21,8 @@
from ..compressors.abc import Compressor, ErrorBound, NamedPerVariableCodec
from ..monitor import progress_bar

TARGET_CHUNK_SIZE = 4 * 1e6  # target chunk volume in bytes (~4 MB)


def compress(
basepath: Path = Path(),
@@ -28,6 +31,7 @@ def compress(
exclude_compressor: Container[str] = tuple(),
include_compressor: None | Container[str] = None,
data_loader_basepath: None | Path = None,
chunked: bool = False,
progress: bool = True,
):
"""Compress datasets with compressors.
@@ -49,6 +53,8 @@
data_loader_basepath : None | Path
Base path for the data loader datasets. If `None`, defaults to `basepath / .. / data-loader`.
Input datasets will be loaded from `data_loader_basepath / datasets`.
    chunked : bool
        Whether to rechunk each variable to roughly `TARGET_CHUNK_SIZE` bytes per chunk
        before compression. Compressed outputs are written to a `<dataset>-chunked` directory.
progress : bool
Whether to show a progress bar during compression.
"""
@@ -84,6 +90,14 @@
ds_mins[vs] = ds[v].min().values.item()
ds_maxs[vs] = ds[v].max().values.item()

if chunked:
for v in ds:
word_size = ds[v].dtype.itemsize
optimal_chunks = get_optimal_chunkshape(
ds[v], TARGET_CHUNK_SIZE, word_size=word_size
)
ds[v] = ds[v].chunk(optimal_chunks)

error_bounds = get_error_bounds(datasets_error_bounds, dataset.parent.name)
registry: Mapping[str, type[Compressor]] = Compressor.registry # type: ignore
for compressor in registry.values():
@@ -100,9 +114,12 @@

for compr_name, named_codecs in compressor_variants.items():
for named_codec in named_codecs:
dataset_name = dataset.parent.name
if chunked:
dataset_name += "-chunked"
compressed_dataset = (
compressed_datasets
/ dataset.parent.name
/ dataset_name
/ named_codec.name
/ compr_name
)
@@ -204,6 +221,162 @@ def get_error_bounds(
]


def get_optimal_chunkshape(f, volume, word_size=4, logger=None):
"""
    Given a CF field `f`, get an optimal chunk shape using knowledge about the various dimensions.
Our working assumption is that we want to have, for
- hourly data, chunk shapes which are multiples of 12 in the time dimension
- sub-daily data, chunk shapes which divide into a small multiple of 24
- daily data, chunk shapes which are a multiple of 10
- monthly data, chunk shapes which are a multiple of 12

Function adapted from: https://github.com/NCAS-CMS/cfs3/blob/390ee593bfea1d926d6b814636b02fb4c430f91e/cfs3/cfchunking.py
"""

default = get_chunkshape(np.array(f.data.shape), volume, word_size, logger)
    t_axis_names = f.cf.axes.get("T", [])
    if not t_axis_names:
        raise ValueError(
            "Cannot identify a time axis, optimal chunk shape not possible"
        )
    # cf_xarray maps each axis key to a list of coordinate names; use the first
    t_axis_name = t_axis_names[0]
t_data = f.cf["T"]
interval = "u"
if len(t_data) > 1:
assert t_data.dtype == np.dtype("datetime64[ns]")
# Input data is in ns. Convert delta unit to "day".
t_delta = (t_data[1] - t_data[0]) / np.timedelta64(1, "D")
t_delta = t_delta.item()

if t_delta < 1:
t_delta = round(t_delta * 24)
if t_delta == 1:
interval = "h"
else:
interval = int(24 / t_delta)
elif t_delta == 1:
interval = "d"
else:
interval = "m"

try:
index = f.dims.index(t_axis_name)
guess = default[index]
match interval:
case "h":
if guess < 3:
default[index] = 2
elif guess < 6:
default[index] = 4
elif guess < 12:
default[index] = 6
elif guess < 19:
default[index] = 12
else:
default[index] = round(guess / 24) * 24
case "d":
default[index] = round(guess / 10) * 10
case "m":
default[index] = round(guess / 12) * 12
case "u":
pass
case _:
default[index] = int(guess / interval) * interval
if default[index] == 0:
default[index] = guess # well that clearly won't work so revert
if guess != default[index] and logger:
logger.info(f"Time chunk changed from {guess} to {default[index]}")
except ValueError:
pass
return default
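As a usage note, here is a minimal sketch of how `get_optimal_chunkshape` can be exercised on a synthetic daily field. It assumes the `.cf` accessor used above is provided by `cf_xarray`, and the shape, variable name, and import path are illustrative only, inferred from this file's location.

```python
# Hedged sketch: exercise get_optimal_chunkshape on a synthetic daily field.
# Assumes cf_xarray supplies the .cf accessor; sizes and names are placeholders.
import cf_xarray  # noqa: F401  # registers the .cf accessor on xarray objects
import numpy as np
import pandas as pd
import xarray as xr

from climatebenchpress.compressor.scripts.compress import (
    TARGET_CHUNK_SIZE,
    get_optimal_chunkshape,
)

time = pd.date_range("2000-01-01", periods=3650, freq="D")
da = xr.DataArray(
    np.zeros((3650, 96, 144), dtype="float32"),
    dims=("time", "lat", "lon"),
    coords={"time": time},
    name="tas",
)

chunks = get_optimal_chunkshape(da, TARGET_CHUNK_SIZE, word_size=da.dtype.itemsize)
print(chunks)  # per-dimension chunk lengths aiming at roughly 4 MB chunks
da_chunked = da.chunk(dict(zip(da.dims, chunks)))
```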


def get_chunkshape(shape, volume, word_size=4, logger=None, scale_tol=0.8):
"""
    Given an array shape (as a NumPy array) and the byte size of its elements, calculate
    a suitable chunk shape for a given volume (in bytes). (We use a word size instead of
    the dtype in case the user changes the data type within the writing operation.)

Function adapted from: https://github.com/NCAS-CMS/cfs3/blob/390ee593bfea1d926d6b814636b02fb4c430f91e/cfs3/cfchunking.py
"""

def constrained_largest_divisor(number, constraint):
"""
        Find the largest divisor of `number` that is no greater than `constraint` (falling back to 1)
"""
for i in range(int(constraint), 1, -1):
if number % i == 0:
return i
return 1

def revise(dimension, guess):
"""
        We need the largest integer no greater than `guess`
        that is a factor of `dimension`, and we need
        to know how much smaller than `guess` it is,
        so that the other dimensions can be scaled up to compensate.
"""
old_guess = guess
# there must be a more elegant way of doing this
guess = constrained_largest_divisor(dimension, old_guess)
scale_factor = old_guess / guess
return scale_factor, guess

v = volume / word_size
size = np.prod(shape)

n_chunks = int(size / v)
root = v ** (1 / shape.size)

# first get a scaled set of initial guess divisors
initial_root = np.full(shape.size, root)
ratios = [x / min(shape) for x in shape]
other_root = 1.0 / (shape.size - 1)
indices = list(range(shape.size))
for i in indices:
factor = ratios[i] ** other_root
initial_root[i] = initial_root[i] * ratios[i]
for j in indices:
if j == i:
continue
initial_root[j] = initial_root[j] / factor

weights_scaling = np.ones(shape.size)

results = []
remaining = 1
for i in indices:
# can't use zip because we are modifying weights in the loop
d = shape[i]
initial_guess = math.ceil(initial_root[i] * weights_scaling[i])
if d % initial_guess == 0:
results.append(initial_guess)
else:
scale_factor, next_guess = revise(d, initial_guess)
results.append(next_guess)
if remaining < shape.size:
scale_factor = scale_factor ** (1 / (shape.size - remaining))
weights_scaling[remaining:] = np.full(
shape.size - remaining, scale_factor
)
remaining += 1
        # fix up the last index as the chunk volume could have drifted quite small
if i == indices[-1]:
size_so_far = np.prod(np.array(results))
scale_miss = size_so_far / v
if scale_miss < scale_tol:
constraint = results[-1] / (scale_miss)
scaled_up = constrained_largest_divisor(shape[-1], constraint)
results[-1] = scaled_up

if logger:
actual_n_chunks = int(np.prod(np.divide(shape, np.array(results))))
        cvolume = int(np.prod(np.array(results)) * word_size)
logger.info(
f"Chunk size {results} - wanted {int(n_chunks)}/{int(volume)}B will get {actual_n_chunks}/{cvolume}B"
)
return results
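A quick, self-contained sanity check of `get_chunkshape` on a made-up 3-D shape (all numbers are illustrative, and the import path is assumed from this file's location): the returned chunk lengths prefer divisors of the dimension sizes, and their product times `word_size` approximates, but need not hit, the requested volume.

```python
# Hedged sketch: sanity-check get_chunkshape on an arbitrary 3-D shape.
import numpy as np

from climatebenchpress.compressor.scripts.compress import get_chunkshape

shape = np.array([720, 180, 360])  # hypothetical (time, lat, lon) sizes
chunks = get_chunkshape(shape, volume=4e6, word_size=4)

bytes_per_chunk = int(np.prod(chunks)) * 4
print(chunks, f"~{bytes_per_chunk} B per chunk")  # lands near, not exactly at, the 4 MB target
```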


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--exclude-dataset", type=str, nargs="+", default=[])
@@ -214,6 +387,7 @@ def get_error_bounds(
parser.add_argument(
"--data-loader-basepath", type=Path, default=Path() / ".." / "data-loader"
)
parser.add_argument("--chunked", action="store_true", default=False)
args = parser.parse_args()

compress(
@@ -223,5 +397,6 @@ def get_error_bounds(
exclude_compressor=args.exclude_compressor,
include_compressor=args.include_compressor,
data_loader_basepath=args.data_loader_basepath,
chunked=args.chunked,
progress=True,
)
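For reference, a hedged sketch of driving the new option from Python rather than the CLI. The import path is inferred from this file's location and the directories are placeholders, not the project's actual layout.

```python
# Hedged sketch: call compress() with the new chunked option.
# Module path and directories are assumptions for illustration.
from pathlib import Path

from climatebenchpress.compressor.scripts.compress import compress

compress(
    basepath=Path("compressor"),               # placeholder working directory
    data_loader_basepath=Path("data-loader"),  # placeholder data-loader checkout
    chunked=True,                              # rechunk variables to ~4 MB chunks first
    progress=False,
)
# Outputs then land under a "<dataset>-chunked" directory,
# matching the dataset_name suffix added above.
```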
7 changes: 6 additions & 1 deletion src/climatebenchpress/compressor/scripts/compute_metrics.py
@@ -85,7 +85,12 @@ def compute_metrics(
/ compressor.stem
)
compressed_dataset_path = compressed_dataset / "decompressed.zarr"
uncompressed_dataset = datasets / dataset.name / "standardized.zarr"
uncompressed_dataset_name = dataset.name
if dataset.name.endswith("-chunked"):
uncompressed_dataset_name = dataset.name.removesuffix("-chunked")
uncompressed_dataset = (
datasets / uncompressed_dataset_name / "standardized.zarr"
)
if not compressed_dataset_path.exists():
print(f"No compressed dataset at {compressed_dataset_path}")
continue
@@ -28,7 +28,11 @@ def concatenate_metrics(basepath: Path = Path()):
if not dataset.is_dir():
continue

with (error_bounds_dir / dataset.name / "error_bounds.json").open() as f:
dataset_name = dataset.name
if dataset.name.endswith("-chunked"):
dataset_name = dataset_name.removesuffix("-chunked")

with (error_bounds_dir / dataset_name / "error_bounds.json").open() as f:
error_bound_list = json.load(f)

for error_bound in dataset.iterdir():
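The two metrics scripts above rely on the same naming convention: a chunked run lives in a directory with a `-chunked` suffix, and the metrics stage strips that suffix to locate the original standardized dataset and its error bounds. A minimal illustration with placeholder paths:

```python
# Hedged sketch of the "-chunked" naming round trip used by the metrics scripts.
# All directory names here are placeholders.
from pathlib import Path

compressed_run = "reanalysis-chunked"                  # hypothetical run directory
source_name = compressed_run.removesuffix("-chunked")  # -> "reanalysis"

uncompressed = Path("datasets") / source_name / "standardized.zarr"
error_bounds = Path("error-bounds") / source_name / "error_bounds.json"
print(uncompressed, error_bounds)
```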