diff --git a/Dockerfile b/Dockerfile
index f26ade7..02dad17 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,7 +2,7 @@ FROM python:3.11-slim
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
RUN apt-get update -y && \
- apt-get -y --no-install-recommends install git unzip build-essential && \
+ apt-get -y --no-install-recommends install git build-essential && \
rm -rf /var/lib/apt/lists/*
# Add repository code
@@ -13,5 +13,8 @@ COPY .git /opt/app/.git
RUN uv sync --no-dev
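+# Print plain Python tracebacks rather than typer's rich-formatted ones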
+ENV _TYPER_STANDARD_TRACEBACK=1
+
ENTRYPOINT ["uv", "run", "--no-dev"]
CMD ["cloudcasting-app"]
diff --git a/pyproject.toml b/pyproject.toml
index 34fe4ec..0ecb483 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ dependencies = [
"fsspec==2024.6.1",
"huggingface-hub==0.28.1",
"hydra-core==1.3.2",
+    "loguru==0.7.3",
"numpy==2.1.2",
"ocf-data-sampler==0.1.4",
"ocf_blosc2==0.0.13",
diff --git a/src/cloudcasting_app/__init__.py b/src/cloudcasting_app/__init__.py
index e69de29..57ed5c0 100644
--- a/src/cloudcasting_app/__init__.py
+++ b/src/cloudcasting_app/__init__.py
@@ -0,0 +1,55 @@
+"""Setup logging configuration for the application."""
+
+import json
+import sys
+import os
+import loguru
+
+def development_formatter(record: "loguru.Record") -> str:
+ """Format a log record for development."""
+ return "".join((
+ "{time:HH:mm:ss.SSS} ",
+ "{level:<7s} [{file}:{line}] | {message} ",
+ "{extra}" if record["extra"] else "",
+ "\n{exception}",
+ ))
+
+def structured_formatter(record: "loguru.Record") -> str:
+ """Format a log record as a structured JSON object."""
+ record["extra"]["serialized"] = json.dumps({
+        "timestamp": record["time"].isoformat(),  # aware datetime, keeps the UTC offset
+ "severity": record["level"].name,
+ "message": record["message"],
+ "logging.googleapis.com/labels": {"python_logger": record["name"]},
+ "logging.googleapis.com/sourceLocation": {
+ "file": record["file"].name,
+ "line": record["line"],
+ "function": record["function"],
+ },
+ "context": record["extra"],
+    }, default=str)  # default=str keeps non-JSON-serializable extras from raising
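+    # Loguru uses a formatter's return value as a template, so return a
+    # reference to the serialized record stashed in "extra" rather than the
+    # JSON itself; braces inside the message then can't break formatting.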
+ return "{extra[serialized]}\n"
+
+# Configure the sinks, replacing loguru's default handler
+loguru.logger.remove(0)
+if sys.stdout.isatty():
+ # Simple logging for development
+ loguru.logger.add(
+ sys.stdout, format=development_formatter, diagnose=True,
+        level=os.getenv("LOGLEVEL", "DEBUG").upper(), backtrace=True, colorize=True,
+ )
+else:
+ # JSON logging for containers
+ loguru.logger.add(
+ sys.stdout, format=structured_formatter, backtrace=True,
+ level=os.getenv("LOGLEVEL", "INFO").upper(),
+ )
+
+# Uncomment (and add `import logging` above) to quieten noisy external libraries
+# for lib in ["aiobotocore", "cfgrib"]:
+#     logging.getLogger(lib).setLevel(logging.WARNING)
+
diff --git a/src/cloudcasting_app/app.py b/src/cloudcasting_app/app.py
index f023e8d..d8d9427 100644
--- a/src/cloudcasting_app/app.py
+++ b/src/cloudcasting_app/app.py
@@ -1,27 +1,25 @@
-"""
-The main script for running the cloudcasting model in production
+"""The main script for running the cloudcasting model in production
This app expects these environmental variables to be available:
SATELLITE_ZARR_PATH (str): The path of the input satellite data
OUTPUT_PREDICTION_DIRECTORY (str): The path of the directory to save the predictions to
"""
-from importlib.metadata import PackageNotFoundError, version
-import logging
import os
-import yaml
-import hydra
-import typer
-import fsspec
+from importlib.metadata import PackageNotFoundError, version
+import fsspec
+import hydra
import pandas as pd
-import xarray as xr
import torch
-
+import typer
+import xarray as xr
+import yaml
from huggingface_hub import snapshot_download
+from loguru import logger
from safetensors.torch import load_model
-from cloudcasting_app.data import prepare_satellite_data, sat_path, get_input_data
+from cloudcasting_app.data import get_input_data, prepare_satellite_data, sat_path
# Get package version
try:
@@ -30,15 +28,6 @@
__version__ = "v?"
# ---------------------------------------------------------------------------
-# GLOBAL SETTINGS
-
-logging.basicConfig(
- level=getattr(logging, os.getenv("LOGLEVEL", "INFO")),
- format="[%(asctime)s] {%(pathname)s:%(lineno)d} %(levelname)s - %(message)s",
-)
-
-# Create a logger
-logger = logging.getLogger(__name__)
# Model will use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -56,9 +45,9 @@ def app(t0=None):
Args:
t0 (datetime): Datetime at which forecast is made
"""
-
- logger.info(f"Using `cloudcasting_app` version: {__version__}")
-
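+    # loguru both formats with keyword arguments and captures them into the
+    # record's "extra" dict, so the version appears in structured output too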
+    logger.info("Using `cloudcasting_app` version: {version}", version=__version__)
+
# ---------------------------------------------------------------------------
# 0. If inference datetime is None, round down to last 30 minutes
if t0 is None:
@@ -72,87 +60,87 @@ def app(t0=None):
# 1. Prepare the input data
logger.info("Downloading satellite data")
prepare_satellite_data(t0)
-
+
# ---------------------------------------------------------------------------
# 2. Load model
logger.info("Loading model")
-
+
hf_download_dir = snapshot_download(
repo_id=REPO_ID,
revision=REVISION,
)
-
- with open(f"{hf_download_dir}/model_config.yaml", "r", encoding="utf-8") as f:
+
+ with open(f"{hf_download_dir}/model_config.yaml", encoding="utf-8") as f:
model = hydra.utils.instantiate(yaml.safe_load(f))
-
+
model = model.to(device)
load_model(
model,
- filename=f"{hf_download_dir}/model.safetensors",
+ filename=f"{hf_download_dir}/model.safetensors",
strict=True,
)
-
+
model.eval()
-
+
# ---------------------------------------------------------------------------
# 3. Get inference inputs
logger.info("Preparing inputs")
-
+
# TODO check the spatial dimensions of this zarr
# Get inputs
ds = xr.open_zarr(sat_path)
-
+
# Reshape to (channel, time, height, width)
ds = ds.transpose("variable", "time", "y_geostationary", "x_geostationary")
-
+
X = get_input_data(ds, t0)
-
+
# Convert to tensor, expand into batch dimension, and move to device
X = X[None, ...].to(device)
-
+
# ---------------------------------------------------------------------------
# 4. Make predictions
logger.info("Making predictions")
-
+
with torch.no_grad():
y_hat = model(X).cpu().numpy()
-
+
# ---------------------------------------------------------------------------
# 5. Save predictions
logger.info("Saving predictions")
da_y_hat = xr.DataArray(
- y_hat,
- dims=["init_time", "variable", "step", "y_geostationary", "x_geostationary"],
+ y_hat,
+ dims=["init_time", "variable", "step", "y_geostationary", "x_geostationary"],
coords={
"init_time": [t0],
"variable": ds.variable,
"step": pd.timedelta_range(start="15min", end="180min", freq="15min"),
"y_geostationary": ds.y_geostationary,
"x_geostationary": ds.x_geostationary,
- }
+ },
)
-
+
ds_y_hat = da_y_hat.to_dataset(name="sat_pred")
ds_y_hat.sat_pred.attrs.update(ds.data.attrs)
-
+
# Save predictions to the latest path and to path with timestring
out_dir = os.environ["OUTPUT_PREDICTION_DIRECTORY"]
-
+
latest_zarr_path = f"{out_dir}/latest.zarr"
t0_string_zarr_path = t0.strftime(f"{out_dir}/%Y-%m-%dT%H:%M.zarr")
-
+
fs = fsspec.open(out_dir).fs
for path in [latest_zarr_path, t0_string_zarr_path]:
-
+
# Remove the path if it exists already
if fs.exists(path):
logger.info(f"Removing path: {path}")
fs.rm(path, recursive=True)
-
+
ds_y_hat.to_zarr(path)
-
-
+
+
def main() -> None:
"""Entrypoint to the application."""
typer.run(app)
diff --git a/src/cloudcasting_app/data.py b/src/cloudcasting_app/data.py
index 0164319..7f6a4c6 100644
--- a/src/cloudcasting_app/data.py
+++ b/src/cloudcasting_app/data.py
@@ -1,14 +1,16 @@
-import numpy as np
-import pandas as pd
-import xarray as xr
-import torch
import logging
import os
+import shutil
+import zipfile
+
import fsspec
-import ocf_blosc2
+import numpy as np
+import pandas as pd
+import torch
+import xarray as xr
from ocf_data_sampler.select.geospatial import lon_lat_to_geostationary_area_coords

logger = logging.getLogger(__name__)
sat_5_path = "sat_5_min.zarr"
@@ -23,17 +24,17 @@
channel_order = [
- 'IR_016',
- 'IR_039',
- 'IR_087',
- 'IR_097',
- 'IR_108',
- 'IR_120',
- 'IR_134',
- 'VIS006',
- 'VIS008',
- 'WV_062',
- 'WV_073',
+ "IR_016",
+ "IR_039",
+ "IR_087",
+ "IR_097",
+ "IR_108",
+ "IR_120",
+ "IR_134",
+ "VIS006",
+ "VIS008",
+ "WV_062",
+ "WV_073",
]
@@ -43,50 +44,52 @@ def crop_input_area(ds):
[lat_min, lat_max],
ds.data,
)
-
+
ds = ds.isel(x_geostationary=slice(None, None, -1)) # x-axis is in decreasing order
ds = ds.sel(
- x_geostationary=slice(x_min, None),
+ x_geostationary=slice(x_min, None),
y_geostationary=slice(y_min, None),
).isel(
x_geostationary=slice(0,614),
y_geostationary=slice(0,372),
)
-
+
ds = ds.isel(x_geostationary=slice(None, None, -1)) # flip back
assert len(ds.x_geostationary)==614
assert len(ds.y_geostationary)==372
-
+
return ds
def prepare_satellite_data(t0: pd.Timestamp):
-
+
# Download the 5 and/or 15 minutely satellite data
download_all_sat_data()
-
+
# Select between the 5/15 minute satellite data sources
combine_5_and_15_sat_data()
-
+
# Check the required expected timestamps are available
check_required_timestamps_available(t0)
-
+
# Load data the data for more preprocessing
ds = xr.open_zarr(sat_path)
-
+
# Crop the input area to expected
ds = crop_input_area(ds)
-
+
# Reorder channels
ds = ds.sel(variable=channel_order)
-
+
# Scale the satellite data from 0-1
ds = ds / 1023
-
+
# Resave
ds = ds.compute()
- os.system(f"rm -rf {sat_path}")
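+    # to_zarr defaults to mode "w-", which fails if the store already exists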
+ if os.path.exists(sat_path):
+ shutil.rmtree(sat_path)
ds.to_zarr(sat_path)
@@ -96,9 +98,11 @@ def download_all_sat_data() -> bool:
Returns:
bool: Whether the download was successful
"""
-
# Clean out old files
- os.system(f"rm -r {sat_path} {sat_5_path} {sat_15_path}")
+    logger.debug("Cleaning out old satellite data")
+ for loc in [sat_path, sat_5_path, sat_15_path]:
+ if os.path.exists(loc):
+ shutil.rmtree(loc)
# Set variable to track whether the satellite download is successful
sat_available = False
@@ -108,42 +112,44 @@ def download_all_sat_data() -> bool:
fs, _ = fsspec.core.url_to_fs(sat_5_dl_path)
if fs.exists(sat_5_dl_path):
sat_available = True
- logger.info(f"Downloading 5-minute satellite data")
+ logger.info("Downloading 5-minute satellite data")
fs.get(sat_5_dl_path, "sat_5_min.zarr.zip")
- os.system(f"unzip -qq sat_5_min.zarr.zip -d {sat_5_path}")
- os.system(f"rm sat_5_min.zarr.zip")
+ with zipfile.ZipFile("sat_5_min.zarr.zip", "r") as zip_ref:
+ zip_ref.extractall(sat_5_path)
+ os.remove("sat_5_min.zarr.zip")
else:
- logger.info(f"No 5-minute data available")
+ logger.info("No 5-minute data available")
# Also download 15-minute satellite if it exists
- sat_15_dl_path = os.environ["SATELLITE_ZARR_PATH"].replace(".zarr", "_15.zarr")
+ sat_15_dl_path = sat_5_dl_path.replace(".zarr", "_15.zarr")
if fs.exists(sat_15_dl_path):
sat_available = True
- logger.info(f"Downloading 15-minute satellite data")
+ logger.info("Downloading 15-minute satellite data")
fs.get(sat_15_dl_path, "sat_15_min.zarr.zip")
- os.system(f"unzip -qq sat_15_min.zarr.zip -d {sat_15_path}")
- os.system(f"rm sat_15_min.zarr.zip")
+ with zipfile.ZipFile("sat_15_min.zarr.zip", "r") as zip_ref:
+ zip_ref.extractall(sat_15_path)
+ os.remove("sat_15_min.zarr.zip")
else:
- logger.info(f"No 15-minute data available")
+ logger.info("No 15-minute data available")
return sat_available
def check_required_timestamps_available(t0: pd.Timestamp):
available_timestamps = get_satellite_timestamps(sat_path)
-
+
# Need 12 timestamps of 15 minutely data up to and including time t0
expected_timestamps = pd.date_range(t0-pd.Timedelta("165min"), t0, freq="15min")
-
+
timestamps_available = np.isin(expected_timestamps, available_timestamps)
-
+
if not timestamps_available.all():
missing_timestamps = expected_timestamps[~timestamps_available]
raise Exception(
"Some required timestamps missing\n"
f"Required timestamps: {expected_timestamps}\n"
f"Available timestamps: {timestamps_available}\n"
- f"Missing timestamps: {missing_timestamps}"
+ f"Missing timestamps: {missing_timestamps}",
)
@@ -162,7 +168,6 @@ def get_satellite_timestamps(sat_zarr_path: str) -> pd.DatetimeIndex:
def combine_5_and_15_sat_data() -> None:
"""Select and/or combine the 5 and 15-minutely satellite data and move it to the expected path"""
-
# Check which satellite data exists
exists_5_minute = os.path.exists(sat_5_path)
exists_15_minute = os.path.exists(sat_15_path)
@@ -175,7 +180,7 @@ def combine_5_and_15_sat_data() -> None:
datetimes_5min = get_satellite_timestamps(sat_5_path)
logger.info(
f"Latest 5-minute timestamp is {datetimes_5min.max()}. "
- f"All the datetimes are: \n{datetimes_5min}"
+ f"All the datetimes are: \n{datetimes_5min}",
)
else:
logger.info("No 5-minute data was found.")
@@ -184,7 +189,7 @@ def combine_5_and_15_sat_data() -> None:
datetimes_15min = get_satellite_timestamps(sat_15_path)
logger.info(
f"Latest 5-minute timestamp is {datetimes_15min.max()}. "
- f"All the datetimes are: \n{datetimes_15min}"
+ f"All the datetimes are: \n{datetimes_15min}",
)
else:
logger.info("No 15-minute data was found.")
@@ -198,26 +203,26 @@ def combine_5_and_15_sat_data() -> None:
# Move the selected data to the expected path
if use_5_minute:
- logger.info(f"Using 5-minutely data.")
+ logger.info("Using 5-minutely data.")
os.system(f"mv {sat_5_path} {sat_path}")
else:
- logger.info(f"Using 15-minutely data.")
+ logger.info("Using 15-minutely data.")
os.system(f"mv {sat_15_path} {sat_path}")
-
+
def get_input_data(ds: xr.Dataset, t0: pd.Timestamp):
-
+
# Slice the data
required_timestamps = pd.date_range(t0-pd.Timedelta("165min"), t0, freq="15min")
ds_sel = ds.reindex(time=required_timestamps)
# Load the data
ds_sel = ds_sel.compute(scheduler="single-threaded")
-
+
# Convert to arrays
X = ds_sel.data.values.astype(np.float32)
# Convert NaNs to -1
X = np.nan_to_num(X, nan=-1)
- return torch.Tensor(X)
\ No newline at end of file
+ return torch.Tensor(X)
diff --git a/tests/conftest.py b/tests/conftest.py
index a2f71d3..40b65bb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,11 +1,10 @@
import os
-import pytest
+import fsspec
import numpy as np
import pandas as pd
+import pytest
import xarray as xr
-import fsspec
-
xr.set_options(keep_attrs=True)
@@ -15,16 +14,16 @@ def test_t0():
def make_sat_data(test_t0, freq_mins):
-
+
# Load dataset which only contains coordinates, but no data
shell_path = f"{os.path.dirname(os.path.abspath(__file__))}/test_data/non_hrv_shell.zarr.zip"
-
+
ds = xr.open_zarr(fsspec.get_mapper(f"zip::{shell_path}"))
- # Remove original time dim
+ # Remove original time dim
ds = ds.drop_vars("time")
- # Add new times so they lead up to present
+ # Add new times so they lead up to present
times = pd.date_range(
test_t0 - pd.Timedelta("3h"),
test_t0,
@@ -37,7 +36,7 @@ def make_sat_data(test_t0, freq_mins):
np.zeros([len(ds[c]) for c in ds.xindexes]),
coords=[ds[c] for c in ds.xindexes],
)
-
+
# Add stored attributes to DataArray
ds.data.attrs = ds.attrs["_data_attrs"]
del ds.attrs["_data_attrs"]
@@ -47,4 +46,4 @@ def make_sat_data(test_t0, freq_mins):
@pytest.fixture()
def sat_5_data(test_t0):
- return make_sat_data(test_t0, freq_mins=5)
\ No newline at end of file
+ return make_sat_data(test_t0, freq_mins=5)
diff --git a/tests/test_app.py b/tests/test_app.py
index 5fad7ff..c763ed7 100644
--- a/tests/test_app.py
+++ b/tests/test_app.py
@@ -1,8 +1,8 @@
-import zarr
import os
-import xarray as xr
-import numpy as np
+import numpy as np
+import xarray as xr
+import zarr
from cloudcasting_app.app import app
@@ -17,30 +17,30 @@ def test_app(sat_5_data, tmp_path, test_t0):
os.environ["OUTPUT_PREDICTION_DIRECTORY"] = f"{tmp_path}"
with zarr.storage.ZipStore("temp_sat.zarr.zip", mode="x") as store:
- sat_5_data.to_zarr(store)
+ sat_5_data.to_zarr(store)
app()
-
+
# Check the two output files have been created
latest_zarr_path = f"{tmp_path}/latest.zarr"
t0_string_zarr_path = test_t0.strftime(f"{tmp_path}/%Y-%m-%dT%H:%M.zarr")
assert os.path.exists(latest_zarr_path)
assert os.path.exists(t0_string_zarr_path)
-
+
# Load the predictions and check them
ds_y_hat = xr.open_zarr(latest_zarr_path)
-
+
assert "sat_pred" in ds_y_hat
assert (
list(ds_y_hat.sat_pred.coords)==
["init_time", "step", "variable", "x_geostationary", "y_geostationary"]
)
-
+
# Make sure all the coords are correct
assert ds_y_hat.init_time == test_t0
assert len(ds_y_hat.step)==12
assert (ds_y_hat.x_geostationary==sat_5_data.x_geostationary).all()
assert (ds_y_hat.y_geostationary==sat_5_data.y_geostationary).all()
-
+
# Make sure all of the predictions are finite
- assert np.isfinite(ds_y_hat.sat_pred).all()
\ No newline at end of file
+ assert np.isfinite(ds_y_hat.sat_pred).all()