Skip to content

Commit

Permalink
Made linter happy
Browse files Browse the repository at this point in the history
  • Loading branch information
goord committed May 23, 2024
1 parent b52721b commit d0e8eaf
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 61 deletions.
36 changes: 29 additions & 7 deletions dales2zarr/convert_int8.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,43 @@
import argparse
import yaml
import xarray as xr
import yaml
import zarr
from dales2zarr.zarr_cast import multi_cast_to_int8


# Parse command-line arguments
def parse_args(arg_list: list[str] | None = None):
def parse_args(arg_list=None):
"""Parse command-line arguments for the convert_int8 script.
Args:
arg_list (list, optional): List of command-line arguments of type str. Defaults to None,
in which case sys.argv[1:] is used.
Returns:
argparse.Namespace: Parsed command-line arguments.
"""
parser = argparse.ArgumentParser(description="Convert input dataset to 8-bit integers and write to zarr")
parser.add_argument("--input", metavar="FILE", type=str, required=True, help="Path to the input dataset file")
parser.add_argument("--output", metavar="FILE", type=str, required=False, default=None, help="Path to the output zarr file")
parser.add_argument("--config", metavar="FILE", type=str, required=False, default=None, help="Path to the input configuration file (yaml)")
parser.add_argument("--input", metavar="FILE", type=str, required=True,
help="Path to the input dataset file")
parser.add_argument("--output", metavar="FILE", type=str, required=False, default=None,
help="Path to the output zarr file")
parser.add_argument("--config", metavar="FILE", type=str, required=False, default=None,
help="Path to the input configuration file (yaml)")
return parser.parse_args(args=arg_list)

def main(arg_list: list[str] | None = None):

def main(arg_list=None):
"""Convert the input dataset to int8 and save it in zarr format.
Args:
arg_list (list, optional): List of command-line arguments. Defaults to None, in which case sys.argv[1:] is used.
Returns:
None
"""
# Parse command-line arguments
args = parse_args(arg_list)

# Read the input dataset from file
input_ds = xr.open_dataset(args.input)

Expand Down
75 changes: 41 additions & 34 deletions dales2zarr/zarr_cast.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
"""This module contains core functionalities to convert double-precision numerical data into
unsigned 8-bit integer values, suitable for visualization"""
"""Utilities for converting floating-point numerical data to unsigned 8-bit integers.
This module contains core functionalities to convert double-precision numerical data into
unsigned 8-bit integer values, suitable for visualization.
"""
import logging
import numpy as np
import xarray as xr
import logging

log = logging.getLogger(__name__)


def normalize_data(data, data_min, data_max, mode, epsilon=1e-10):
"""Normalize input data array and cast to 8-bit integer
"""Normalize input data array and cast to 8-bit integer.
Args:
data (array): Input numerical floating-point data to convert to 8-bit values
Expand Down Expand Up @@ -58,7 +61,7 @@ def normalize_data(data, data_min, data_max, mode, epsilon=1e-10):


def cast_to_int8_3d(input_ds, var_name, mode='linear', epsilon=1e-10):
"""Convert 3d xarray to 8-bit integers
"""Convert 3d xarray to 8-bit integers.
Args:
input_ds (xarray.Dataset): Input xarray dataset
Expand All @@ -82,15 +85,15 @@ def cast_to_int8_3d(input_ds, var_name, mode='linear', epsilon=1e-10):
input_data = input_ds[var_name]
glob_min = input_data.min().values
log.info(f'global minimum is {glob_min}')

# Compute the maximum values for each layer
layer_maxes = input_data.max(['xt', 'yt']).values
log.info(f'layer maxes are {layer_maxes}')

# Compute the global maximum
glob_max = np.max(layer_maxes)
log.info(f'computed min and max: {glob_min}, {glob_max}')

# Normalize the layer maxes and get the indices of the bottom and top layers
norms, _, _ = normalize_data(layer_maxes, glob_min, glob_max, mode, epsilon)
kbot, ktop, kmax = -1, -1, norms.shape[-1]
Expand All @@ -100,43 +103,44 @@ def cast_to_int8_3d(input_ds, var_name, mode='linear', epsilon=1e-10):
if np.max(norms[..., kmax - k - 1]) != 0 and ktop < 0:
ktop = kmax - k - 1
log.info(f'computed kbot and ktop: {kbot}, {ktop}')

# Get the heights associated with the bottom and top layers
zbot, ztop = input_ds.zt[kbot], input_ds.zt[ktop]
log.info(f'computed associated heights are: {zbot}, {ztop}')

# Create a new array with the appropriate shape
new_shape = list(input_data.shape)
z_index = 0 if len(new_shape) == 3 else 1
new_shape[z_index] = ktop - kbot + 1
arr = np.zeros(new_shape, dtype='uint8')

# Convert each time slice to 8-bit integers
if 'time' in input_ds.sizes:
ntimes = input_ds.sizes['time']
for itime in range(ntimes):
log.info(f'Converting time slice {itime} from {ntimes}...')
arr[itime, ...], s1, s2 = normalize_data(input_data.isel(time=itime).isel(zt=slice(kbot, ktop + 1)), glob_min, glob_max, mode, epsilon)
arr[itime, ...], s1, s2 = normalize_data(input_data.isel(time=itime).isel(zt=slice(kbot, ktop + 1)),
glob_min, glob_max, mode, epsilon)
else:
log.info("Converting single time slice...")
arr[...], s1, s2 = normalize_data(input_data.isel(zt=slice(kbot, ktop + 1)), glob_min, glob_max, mode, epsilon)

arr[...], s1, s2 = normalize_data(input_data.isel(zt=slice(kbot, ktop + 1)),
glob_min, glob_max, mode, epsilon)

# Create a new dataset with the converted variable
output = input_data.isel(zt=slice(kbot, ktop + 1)).to_dataset()
output[var_name].values = arr
output.attrs['lower_bound'] = s1
output.attrs['upper_bound'] = s2
output.attrs['zbot'] = zbot
output.attrs['ztop'] = ztop

return output


def cast_to_int8_2d(input_ds, var_name, mode='linear', epsilon=1e-10):
"""
Converts a 2D variable in a given xarray Dataset to an 8-bit integer representation.
"""Converts a 2D variable in a given xarray Dataset to an 8-bit integer representation.
Parameters:
Args:
input_ds (xarray.Dataset): The input dataset containing the variable to be converted.
var_name (str): The name of the variable to be converted.
mode (str, optional): The mode used for normalization. Default is 'linear'.
Expand All @@ -149,30 +153,32 @@ def cast_to_int8_2d(input_ds, var_name, mode='linear', epsilon=1e-10):
ValueError: If `mode` is not `linear` or `log`.
Examples:
# Convert a 2D variable named 'temperature' in the input dataset to an 8-bit integer representation
# Convert a 2D variable named 'temperature' in the input dataset to an 8-bit integer representation:
output_ds = cast_to_int8_2d(input_ds, 'temperature')
# Convert a 2D variable named 'humidity' in the input dataset to an 8-bit integer representation using 'log' mode
# Convert a 2D variable named 'humidity' in the input dataset to an 8-bit integer representation
# using 'log' mode:
output_ds = cast_to_int8_2d(input_ds, 'humidity', mode='log')
"""
# Compute the global minimum and maximum values of the input variable
glob_min = input_ds[var_name].min().values
glob_max = input_ds[var_name].max().values
log.info(f'computed min and max: {glob_min}, {glob_max}')

# Create an array of zeros with the same shape as the input variable
arr = np.zeros(list(input_ds[var_name].shape), dtype='uint8')

# Convert each time slice of the input variable to 8-bit integers
if 'time' in input_ds.sizes:
num_times = input_ds.sizes['time']
for itime in range(num_times):
log.info(f'Converting time slice {itime} from {num_times}...')
arr[itime,...], s1, s2 = normalize_data(input_ds[var_name].isel(time=itime), glob_min, glob_max, mode)
arr[itime,...], s1, s2 = normalize_data(input_ds[var_name].isel(time=itime),
glob_min, glob_max, mode, epsilon)
else:
log.info('Converting single time slice...')
arr[...], s1, s2 = normalize_data(input_ds[var_name], glob_min, glob_max, mode)
arr[...], s1, s2 = normalize_data(input_ds[var_name], glob_min, glob_max, mode, epsilon)

# Create a new dataset with the converted variable
output = input_ds[var_name].to_dataset()
output[var_name].values = arr
Expand All @@ -182,13 +188,13 @@ def cast_to_int8_2d(input_ds, var_name, mode='linear', epsilon=1e-10):


def cast_to_int8(input_ds, input_var, output_var=None, mode='linear', epsilon=1e-10):
"""
Casts a variable in the input dataset to an 8-bit integer.
"""Casts a variable in the input dataset to an 8-bit integer.
Parameters:
Args:
input_ds (xarray.Dataset): The input dataset.
input_var (str): The name of the variable to cast.
output_var (str, optional): The name of the output variable. If not provided, the input variable name will be used.
output_var (str, optional): The name of the output variable.
If not provided, the input variable name will be used.
mode (str, optional): The casting mode. Defaults to 'linear'.
epsilon (float, optional): Offset used for the logarithmic mapping.
Expand All @@ -204,16 +210,17 @@ def cast_to_int8(input_ds, input_var, output_var=None, mode='linear', epsilon=1e
input_ds = input_ds.rename_vars({input_var: var_name})
dims = input_ds[var_name].coords
if "zt" in dims or "zm" in dims:
return cast_to_int8_3d(input_ds, var_name, mode, epsilon).rename_dims({"zt": "z_" + var_name}).rename_vars({"zt": "z_" + var_name})
return (cast_to_int8_3d(input_ds, var_name, mode, epsilon)
.rename_dims({"zt": "z_" + var_name})
.rename_vars({"zt": "z_" + var_name}))
else:
return cast_to_int8_2d(input_ds, var_name, mode, epsilon)


def multi_cast_to_int8(input_ds, input_config):
"""
Casts multiple variables in the input dataset to 8-bit integers.
"""Casts multiple variables in the input dataset to 8-bit integers.
Parameters:
Args:
input_ds (xarray.Dataset): The input dataset.
input_config (dict): A dictionary containing the configuration for casting the variables.
Expand All @@ -226,4 +233,4 @@ def multi_cast_to_int8(input_ds, input_config):
if int8_var is not None:
outputs.append(int8_var)
variables.append(var_options.get('output_var', input_var))
return xr.merge(outputs), variables
return xr.merge(outputs), variables
43 changes: 29 additions & 14 deletions tests/test_convert_int8.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
import os
import tempfile
import numpy as np
import yaml
import xarray as xr
import yaml
from dales2zarr.convert_int8 import main

# These tests have been created with the help of github copilot

def test_main_with_default_config():
"""Test the main function with the default configuration.
This test case creates a temporary directory to store the output zarr file.
It sets the input and output file paths and creates a sample input dataset.
The input dataset is saved to a netCDF file and then passed to the main function.
The test checks if the output zarr file exists and reads the output dataset from it.
It also checks if the output dataset has the expected variables and data types.
Returns:
None
"""
# Create a temporary directory to store the output zarr file
with tempfile.TemporaryDirectory() as temp_dir:
# Set the input and output file paths
Expand Down Expand Up @@ -38,45 +49,49 @@ def test_main_with_default_config():


def test_main_with_custom_config():
# Create a temporary directory to store the output zarr file
"""Test the main function with a custom configuration.
This test case performs the following steps:
1. Creates a temporary directory to store the output zarr file.
2. Sets the input and output file paths.
3. Creates a sample input dataset with ql and qr variables.
4. Saves the input dataset to a netCDF file.
5. Creates a sample input configuration.
6. Saves the input configuration to a yaml file.
7. Calls the main function with the input, output, and config file paths.
8. Checks if the output zarr file exists.
9. Reads the output dataset from the zarr file.
10. Checks if the output dataset has the expected variables.
11. Checks if the output dataset variables have the expected data type.
12. Checks if the output dataset variables have the expected values.
"""
with tempfile.TemporaryDirectory() as temp_dir:
# Set the input and output file paths
input_file = os.path.join(temp_dir, "input.nc")
output_file = os.path.join(temp_dir, "output.zarr")
config_file = os.path.join(temp_dir, "config.yaml")

# Create a sample input dataset
ql_input_data = np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])
qr_input_data = np.array([[[10.0, 20.0], [30.0, 40.0]], [[50.0, 60.0], [70.0, 80.0]]])
input_ds = xr.Dataset({'ql': (['zt', 'yt', 'xt'], ql_input_data), 'qr': (['zt', 'yt', 'xt'], qr_input_data)})

# Save the input dataset to a netCDF file
input_ds.to_netcdf(input_file)

# Create a sample input configuration
input_config = {"ql": {"mode": "log"}, "qr": {"mode": "linear"}}

# Save the input configuration to a yaml file
with open(config_file, "w") as f:
yaml.safe_dump(input_config, f)

# Call the main function
main(["--input", input_file, "--output", output_file, "--config", config_file])

# Check if the output zarr file exists
assert os.path.exists(output_file)

# Read the output dataset from the zarr file
output_data = xr.open_zarr(output_file)

# Check if the output dataset has the expected variables
assert "ql" in output_data
assert "qr" in output_data

# Check if the output dataset variables have the expected data type
assert output_data["ql"].dtype == "uint8"
assert output_data["qr"].dtype == "uint8"

# Check if the output dataset variables have the expected values
assert output_data["ql"].values.flat[:3].tolist() == [0, 84, 134]
assert output_data["qr"].values.flat[:3].tolist() == [0, 36, 72]
assert output_data["qr"].values.flat[:3].tolist() == [0, 36, 72]
Loading

0 comments on commit d0e8eaf

Please sign in to comment.