Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added icon to nwp providers #72

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
132 changes: 128 additions & 4 deletions ocf_data_sampler/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
NWP_PROVIDERS = [
"ukv",
"ecmwf",
"gfs"
"gfs",
"icon_eu",
]
# TODO add ICON


def _to_data_array(d):
Expand Down Expand Up @@ -173,16 +173,140 @@ def __getitem__(self, key):
GFS_STD = _to_data_array(GFS_STD)
GFS_MEAN = _to_data_array(GFS_MEAN)

# ------ ICON-EU
# Statistics for ICON-EU variables
ICON_EU_STD = {
"alb_rad": 13.7881,
"alhfl_s": 73.7198,
"ashfl_s": 54.8027,
"asob_s": 55.8319,
"asob_t": 74.9360,
"aswdifd_s": 21.4940,
"aswdifu_s": 18.7688,
"aswdir_s": 54.4683,
"athb_s": 34.8575,
"athb_t": 42.9108,
"aumfl_s": 0.1460,
"avmfl_s": 0.1892,
"cape_con": 32.2570,
"cape_ml": 106.3998,
"clch": 39.9324,
"clcl": 36.3961,
"clcm": 41.1690,
"clct": 34.7696,
"clct_mod": 0.4227,
"cldepth": 0.1739,
"h_snow": 0.9012,
"hbas_con": 1306.6632,
"htop_con": 1810.5665,
"htop_dc": 459.0422,
"hzerocl": 1144.6469,
"pmsl": 1103.3301,
"ps": 4761.3184,
"qv_2m": 0.0024,
"qv_s": 0.0038,
"rain_con": 1.7097,
"rain_gsp": 4.2654,
"relhum_2m": 15.3779,
"rho_snow": 120.2461,
"runoff_g": 0.7410,
"runoff_s": 2.1930,
"snow_con": 1.1432,
"snow_gsp": 1.8154,
"snowlmt": 656.0699,
"synmsg_bt_cl_ir10.8": 17.9438,
"t_2m": 7.7973,
"t_g": 8.7053,
"t_snow": 134.6874,
"tch": 0.0052,
"tcm": 0.0133,
"td_2m": 7.1460,
"tmax_2m": 7.8218,
"tmin_2m": 7.8346,
"tot_prec": 5.6312,
"tqc": 0.0976,
"tqi": 0.0247,
"u_10m": 3.8351,
"v_10m": 5.0083,
"vmax_10m": 5.5037,
"w_snow": 286.1510,
"ww": 27.2974,
"z0": 0.3901,
}

ICON_EU_MEAN = {
"alb_rad": 15.4437,
"alhfl_s": -54.9398,
"ashfl_s": -19.4684,
"asob_s": 40.9305,
"asob_t": 61.9244,
"aswdifd_s": 19.7813,
"aswdifu_s": 8.8328,
"aswdir_s": 29.9820,
"athb_s": -53.9873,
"athb_t": -212.8088,
"aumfl_s": 0.0558,
"avmfl_s": 0.0078,
"cape_con": 16.7397,
"cape_ml": 21.2189,
"clch": 26.4262,
"clcl": 57.1591,
"clcm": 36.1702,
"clct": 72.9254,
"clct_mod": 0.5561,
"cldepth": 0.1356,
"h_snow": 0.0494,
"hbas_con": 108.4975,
"htop_con": 433.0623,
"htop_dc": 454.0859,
"hzerocl": 1696.6272,
"pmsl": 101778.8281,
"ps": 99114.4766,
"qv_2m": 0.0049,
"qv_s": 0.0065,
"rain_con": 0.4869,
"rain_gsp": 0.9783,
"relhum_2m": 78.2258,
"rho_snow": 62.5032,
"runoff_g": 0.1301,
"runoff_s": 0.4119,
"snow_con": 0.2188,
"snow_gsp": 0.4317,
"snowlmt": 1450.3241,
"synmsg_bt_cl_ir10.8": 265.0639,
"t_2m": 278.8212,
"t_g": 279.9216,
"t_snow": 162.5582,
"tch": 0.0047,
"tcm": 0.0091,
"td_2m": 274.9544,
"tmax_2m": 279.3550,
"tmin_2m": 278.2519,
"tot_prec": 2.1158,
"tqc": 0.0424,
"tqi": 0.0108,
"u_10m": 1.1902,
"v_10m": -0.4733,
"vmax_10m": 8.4152,
"w_snow": 14.5936,
"ww": 15.3570,
"z0": 0.2386,
}

# Convert the plain dicts to DataArray form via _to_data_array, matching how
# the other providers' statistics (e.g. GFS above) are stored.
ICON_EU_STD = _to_data_array(ICON_EU_STD)
ICON_EU_MEAN = _to_data_array(ICON_EU_MEAN)

NWP_STDS = NWPStatDict(
ukv=UKV_STD,
ecmwf=ECMWF_STD,
gfs=GFS_STD
gfs=GFS_STD,
icon_eu=ICON_EU_STD,
)
NWP_MEANS = NWPStatDict(
ukv=UKV_MEAN,
ecmwf=ECMWF_MEAN,
gfs=GFS_MEAN
gfs=GFS_MEAN,
icon_eu=ICON_EU_MEAN,
)

# ------ Satellite
Expand Down
4 changes: 3 additions & 1 deletion ocf_data_sampler/load/nwp/nwp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from ocf_data_sampler.load.nwp.providers.ukv import open_ukv
from ocf_data_sampler.load.nwp.providers.ecmwf import open_ifs

from ocf_data_sampler.load.nwp.providers.icon import open_icon_eu

def open_nwp(zarr_path: str | list[str], provider: str) -> xr.DataArray:
"""Opens NWP zarr
Expand All @@ -16,6 +16,8 @@ def open_nwp(zarr_path: str | list[str], provider: str) -> xr.DataArray:
_open_nwp = open_ukv
elif provider.lower() == "ecmwf":
_open_nwp = open_ifs
elif provider.lower() == "icon-eu":
Sukh-P marked this conversation as resolved.
Show resolved Hide resolved
_open_nwp = open_icon_eu
else:
raise ValueError(f"Unknown provider: {provider}")
return _open_nwp(zarr_path)
Expand Down
48 changes: 48 additions & 0 deletions ocf_data_sampler/load/nwp/providers/icon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""DWD ICON Loading"""

import pandas as pd
from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing
import xarray as xr
import fsspec

from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths

def remove_isobaric_levels_from_coords(nwp: xr.Dataset) -> xr.Dataset:
    """
    Removes the isobaric levels from the coordinates of the NWP data.

    Any data variable that carries the ``isobaricInhPa`` dimension is dropped,
    along with the ``isobaricInhPa`` coordinate itself, leaving only
    surface-level variables.

    Args:
        nwp: NWP data

    Returns:
        NWP data without isobaric levels in the coordinates
    """
    variables_to_drop = [var for var in nwp.data_vars if 'isobaricInhPa' in nwp[var].dims]
    return nwp.drop_vars(["isobaricInhPa"] + variables_to_drop)


# Backwards-compatible alias: the function was first published under this
# misspelled name ("lelvels"); kept so existing callers do not break.
remove_isobaric_lelvels_from_coords = remove_isobaric_levels_from_coords

def open_icon_eu(zarr_path) -> xr.DataArray:
    """
    Opens the ICON-EU data.

    ICON EU Data is on a regular lat/lon grid.
    It has data on multiple pressure levels, as well as the surface.
    Each of the variables is its own data variable.

    Args:
        zarr_path: Path (or list of paths) to the zarr store(s) to open

    Returns:
        Xarray DataArray of the NWP data, with dims
        (init_time_utc, step, channel, latitude, longitude)
    """
    # Open the data
    nwp = open_zarr_paths(zarr_path, time_dim="time")
    nwp = nwp.rename({"time": "init_time_utc"})
    # Sanity checks.
    check_time_unique_increasing(nwp.init_time_utc)
    # Steps 0-78 are one-hour steps; beyond that the steps are 3-hourly,
    # so only the hourly portion is kept.
    nwp = nwp.isel(step=slice(0, 78))
    nwp = remove_isobaric_lelvels_from_coords(nwp)
    # Stack the remaining surface variables into a single "channel" dimension;
    # note the return annotation is DataArray because of this to_array() call.
    nwp = nwp.to_array().rename({"variable": "channel"})
    nwp = nwp.transpose('init_time_utc', 'step', 'channel', 'latitude', 'longitude')
    nwp = make_spatial_coords_increasing(nwp, x_coord="longitude", y_coord="latitude")
    return nwp
39 changes: 39 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,45 @@ def nwp_ecmwf_zarr_path(session_tmp_path, ds_nwp_ecmwf):
ds.to_zarr(zarr_path)
yield zarr_path

@pytest.fixture(scope="session")
def icon_eu_zarr_path(session_tmp_path):
date = "20211101"
hours = ["00", "06"]
paths = []

for hour in hours:
time = f"{date}_{hour}"
ds = xr.Dataset(
coords={
'isobaricInhPa': [50.0, 500.0, 700.0, 850.0, 950.0, 1000.0],
'latitude': np.linspace(29.5, 35.69, 100),
'longitude': np.linspace(-23.5, -17.31, 100),
'step': pd.timedelta_range(start='0h', end='5D', periods=93),
'time': pd.Timestamp(f"2021-11-01T{hour}:00:00"),
},
data_vars={
't': (('step', 'isobaricInhPa', 'latitude', 'longitude'),
np.random.rand(93, 6, 100, 100).astype(np.float32)),
'u_10m': (('step', 'latitude', 'longitude'),
np.random.rand(93, 100, 100).astype(np.float32)),
'v_10m': (('step', 'latitude', 'longitude'),
np.random.rand(93, 100, 100).astype(np.float32)),
},
attrs={
'Conventions': 'CF-1.7',
'GRIB_centre': 'edzw',
'GRIB_centreDescription': 'Offenbach',
'GRIB_edition': 2,
'institution': 'Offenbach'
}
)
ds.coords['valid_time'] = ds.time + ds.step
zarr_path = session_tmp_path / f"{time}.zarr"
ds.to_zarr(zarr_path)
paths.append(zarr_path)

return paths


@pytest.fixture(scope="session")
def ds_uk_gsp():
Expand Down
8 changes: 8 additions & 0 deletions tests/load/test_load_nwp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,12 @@ def test_load_ecmwf(nwp_ecmwf_zarr_path):
assert isinstance(da, DataArray)
assert da.dims == ("init_time_utc", "step", "channel", "longitude", "latitude")
assert da.shape == (24 * 7, 15, 3, 15, 12)
assert np.issubdtype(da.dtype, np.number)


def test_load_icon_eu(icon_eu_zarr_path):
    """Check the ICON-EU loader returns a numeric DataArray with the expected layout."""
    da = open_nwp(zarr_path=icon_eu_zarr_path, provider="icon-eu")

    assert isinstance(da, DataArray)
    expected_dims = ("init_time_utc", "step", "channel", "latitude", "longitude")
    assert da.dims == expected_dims
    # 2 init times, 78 hourly steps, 2 surface channels, 100x100 lat/lon grid
    assert da.shape == (2, 78, 2, 100, 100)
    assert np.issubdtype(da.dtype, np.number)
76 changes: 76 additions & 0 deletions utils/compute_icon_mean_stddev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
""" Script to compute normalisation constants from NWP data """

import xarray as xr
import numpy as np
import glob
import argparse

from ocf_data_sampler.load.nwp.providers.icon import open_icon_eu

# Add argument parser
parser = argparse.ArgumentParser(description='Compute normalization constants from NWP data')
parser.add_argument('--data-path', type=str, required=True,
help='Path pattern to zarr files (e.g., "/path/to/data/*.zarr.zip")')
parser.add_argument('--n-samples', type=int, default=2000,
help='Number of random samples to use (default: 2000)')

args = parser.parse_args()

zarr_files = glob.glob(args.data_path)
n_samples = args.n_samples

ds = open_icon_eu(zarr_files)

n_init_times = ds.sizes['init_time_utc']
n_lats = ds.sizes['latitude']
n_longs = ds.sizes['longitude']
n_steps = ds.sizes['step']

random_init_times = np.random.choice(n_init_times, size=n_samples, replace=True)
random_lats = np.random.choice(n_lats, size=n_samples, replace=True)
random_longs = np.random.choice(n_longs, size=n_samples, replace=True)
random_steps = np.random.choice(n_steps, size=n_samples, replace=True)

samples = []
for i in range(n_samples):
sample = ds.isel(init_time_utc=random_init_times[i],
latitude=random_lats[i],
longitude=random_longs[i],
step=random_steps[i])
samples.append(sample)

samples_stack = xr.concat(samples, dim='samples')


# variables = [
# "alb_rad", "aswdifd_s", "aswdir_s", "cape_con", "clch", "clcl", "clcm",
# "clct", "h_snow", "omega", "pmsl", "relhum_2m", "runoff_g", "runoff_s",
# "t", "t_2m", "t_g", "td_2m", "tot_prec", "u", "u_10m", "v", "v_10m",
# "vmax_10m", "w_snow", "ww", "z0"
# ]
print(samples_stack)

available_channels = samples_stack.channel.values.tolist()
print("Available channels: ", available_channels)

ICON_EU_MEAN = {}
ICON_EU_STD = {}

for var in available_channels:
if var not in available_channels:
print(f"Warning: Variable '{var}' not found in the channel coordinate; skipping.")
continue
var_data = samples_stack.sel(channel=var)
var_mean = float(var_data.mean().compute())
var_std = float(var_data.std().compute())

ICON_EU_MEAN[var] = var_mean
ICON_EU_STD[var] = var_std

print(f"Processed {var}: mean={var_mean:.4f}, std={var_std:.4f}")

print("\nMean values:")
print(ICON_EU_MEAN)
print("\nStandard deviations:")
print(ICON_EU_STD)