diff --git a/Dockerfile b/Dockerfile
index f26ade7..02dad17 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,7 +2,7 @@ FROM python:3.11-slim
 COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
 
 RUN apt-get update -y && \
-    apt-get -y --no-install-recommends install git unzip build-essential && \
+    apt-get -y --no-install-recommends install git build-essential && \
     rm -rf /var/lib/apt/lists/*
 
 # Add repository code
@@ -13,5 +13,7 @@ COPY .git /opt/app/.git
 
 RUN uv sync --no-dev
 
+ENV _TYPER_STANDARD_TRACEBACK=1
+
 ENTRYPOINT ["uv", "run", "--no-dev"]
 CMD ["cloudcasting-app"]
diff --git a/pyproject.toml b/pyproject.toml
index 34fe4ec..0ecb483 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ dependencies = [
     "fsspec==2024.6.1",
     "huggingface-hub==0.28.1",
     "hydra-core==1.3.2",
+    "loguru==0.7.3",
     "numpy==2.1.2",
     "ocf-data-sampler==0.1.4",
     "ocf_blosc2==0.0.13",
diff --git a/src/cloudcasting_app/__init__.py b/src/cloudcasting_app/__init__.py
index e69de29..57ed5c0 100644
--- a/src/cloudcasting_app/__init__.py
+++ b/src/cloudcasting_app/__init__.py
@@ -0,0 +1,51 @@
+"""Setup logging configuration for the application."""
+
+import json
+import os
+import sys
+
+import loguru
+
+def development_formatter(record: "loguru.Record") -> str:
+    """Format a log record for development."""
+    return "".join((
+        "{time:HH:mm:ss.SSS} ",
+        "{level:<7s} [{file}:{line}] | {message} ",
+        "{extra}" if record["extra"] else "",
+        "\n{exception}",
+    ))
+
+def structured_formatter(record: "loguru.Record") -> str:
+    """Format a log record as a structured JSON object."""
+    record["extra"]["serialized"] = json.dumps({
+        "timestamp": record["time"].strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+        "severity": record["level"].name,
+        "message": record["message"],
+        "logging.googleapis.com/labels": {"python_logger": record["name"]},
+        "logging.googleapis.com/sourceLocation": {
+            "file": record["file"].name,
+            "line": record["line"],
+            "function": record["function"],
+        },
+        "context": record["extra"],
+    })
+    return "{extra[serialized]}\n"
+
+# Define the logging formatter, removing the default one
+loguru.logger.remove(0)
+if sys.stdout.isatty():
+    # Simple logging for development
+    loguru.logger.add(
+        sys.stdout, format=development_formatter, diagnose=True,
+        level=os.getenv("LOGLEVEL", "DEBUG"), backtrace=True, colorize=True,
+    )
+else:
+    # JSON logging for containers
+    loguru.logger.add(
+        sys.stdout, format=structured_formatter, backtrace=True,
+        level=os.getenv("LOGLEVEL", "INFO").upper(),
+    )
+
+# Uncomment and change the list to quieten external libraries
+# for logger in ["aiobotocore", "cfgrib"]:
+#     logging.getLogger(logger).setLevel(logging.WARNING)
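A minimal sketch of how this logging setup behaves once the package is imported. The `bind()` call and the sample output line below are illustrative, not part of the diff: with a non-TTY stdout, `structured_formatter` emits one JSON object per record, and context attached via loguru's `bind()` lands under the "context" key.

```python
# Illustrative only: importing the package runs the loguru configuration above.
import cloudcasting_app  # noqa: F401  (side effect: configures loguru)
from loguru import logger

# Context bound here is emitted in the JSON "context" field in containers
logger.bind(component="example").info("Satellite data downloaded")
# Example non-TTY output (abbreviated):
# {"timestamp": "2025-01-01T12:00:00.000000Z", "severity": "INFO",
#  "message": "Satellite data downloaded", "context": {"component": "example"}, ...}
```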
diff --git a/src/cloudcasting_app/app.py b/src/cloudcasting_app/app.py
index f023e8d..d8d9427 100644
--- a/src/cloudcasting_app/app.py
+++ b/src/cloudcasting_app/app.py
@@ -1,27 +1,25 @@
-"""
-The main script for running the cloudcasting model in production
+"""The main script for running the cloudcasting model in production
 
 This app expects these environmental variables to be available:
     SATELLITE_ZARR_PATH (str): The path of the input satellite data
    OUTPUT_PREDICTION_DIRECTORY (str): The path of the directory to save the predictions to
 """
 
-from importlib.metadata import PackageNotFoundError, version
-import logging
 import os
-import yaml
-import hydra
-import typer
-import fsspec
+from importlib.metadata import PackageNotFoundError, version
 
+import fsspec
+import hydra
 import pandas as pd
-import xarray as xr
 import torch
-
+import typer
+import xarray as xr
+import yaml
 from huggingface_hub import snapshot_download
 from safetensors.torch import load_model
+from loguru import logger
 
-from cloudcasting_app.data import prepare_satellite_data, sat_path, get_input_data
+from cloudcasting_app.data import get_input_data, prepare_satellite_data, sat_path
 
 # Get package version
 try:
@@ -30,15 +28,6 @@
     __version__ = "v?"
 
 # ---------------------------------------------------------------------------
-# GLOBAL SETTINGS
-
-logging.basicConfig(
-    level=getattr(logging, os.getenv("LOGLEVEL", "INFO")),
-    format="[%(asctime)s] {%(pathname)s:%(lineno)d} %(levelname)s - %(message)s",
-)
-
-# Create a logger
-logger = logging.getLogger(__name__)
 
 # Model will use GPU if available
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -56,9 +45,8 @@ def app(t0=None):
     Args:
         t0 (datetime): Datetime at which forecast is made
     """
-
-    logger.info(f"Using `cloudcasting_app` version: {__version__}")
-
+    logger.info(f"Using `cloudcasting_app` version: {__version__}", version=__version__)
+
     # ---------------------------------------------------------------------------
     # 0. If inference datetime is None, round down to last 30 minutes
     if t0 is None:
@@ -72,87 +60,87 @@ def app(t0=None):
     # 1. Prepare the input data
     logger.info("Downloading satellite data")
     prepare_satellite_data(t0)
-
+
     # ---------------------------------------------------------------------------
     # 2. Load model
     logger.info("Loading model")
-
+
     hf_download_dir = snapshot_download(
         repo_id=REPO_ID,
         revision=REVISION,
     )
-
-    with open(f"{hf_download_dir}/model_config.yaml", "r", encoding="utf-8") as f:
+
+    with open(f"{hf_download_dir}/model_config.yaml", encoding="utf-8") as f:
         model = hydra.utils.instantiate(yaml.safe_load(f))
-
+
     model = model.to(device)
 
     load_model(
         model,
-        filename=f"{hf_download_dir}/model.safetensors", 
+        filename=f"{hf_download_dir}/model.safetensors",
         strict=True,
     )
-
+
     model.eval()
-
+
     # ---------------------------------------------------------------------------
     # 3. Get inference inputs
     logger.info("Preparing inputs")
-
+
     # TODO check the spatial dimensions of this zarr
     # Get inputs
     ds = xr.open_zarr(sat_path)
-
+
     # Reshape to (channel, time, height, width)
     ds = ds.transpose("variable", "time", "y_geostationary", "x_geostationary")
-
+
     X = get_input_data(ds, t0)
-
+
     # Convert to tensor, expand into batch dimension, and move to device
     X = X[None, ...].to(device)
-
+
     # ---------------------------------------------------------------------------
     # 4. Make predictions
     logger.info("Making predictions")
-
+
     with torch.no_grad():
         y_hat = model(X).cpu().numpy()
-
+
     # ---------------------------------------------------------------------------
     # 5. Save predictions
     logger.info("Saving predictions")
 
     da_y_hat = xr.DataArray(
-        y_hat, 
-        dims=["init_time", "variable", "step", "y_geostationary", "x_geostationary"], 
+        y_hat,
+        dims=["init_time", "variable", "step", "y_geostationary", "x_geostationary"],
         coords={
             "init_time": [t0],
             "variable": ds.variable,
             "step": pd.timedelta_range(start="15min", end="180min", freq="15min"),
             "y_geostationary": ds.y_geostationary,
             "x_geostationary": ds.x_geostationary,
-        }
+        },
     )
-
+
     ds_y_hat = da_y_hat.to_dataset(name="sat_pred")
     ds_y_hat.sat_pred.attrs.update(ds.data.attrs)
-
+
     # Save predictions to the latest path and to path with timestring
     out_dir = os.environ["OUTPUT_PREDICTION_DIRECTORY"]
-
+
     latest_zarr_path = f"{out_dir}/latest.zarr"
     t0_string_zarr_path = t0.strftime(f"{out_dir}/%Y-%m-%dT%H:%M.zarr")
-
+
     fs = fsspec.open(out_dir).fs
     for path in [latest_zarr_path, t0_string_zarr_path]:
-
+
         # Remove the path if it exists already
         if fs.exists(path):
             logger.info(f"Removing path: {path}")
             fs.rm(path, recursive=True)
-
+
         ds_y_hat.to_zarr(path)
-
-
+
+
 def main() -> None:
     """Entrypoint to the application."""
     typer.run(app)
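The dataset written in step 5 has a fixed layout, which makes a quick sanity check easy. A hedged sketch, assuming a hypothetical local stand-in for `OUTPUT_PREDICTION_DIRECTORY`:

```python
# Illustrative only: inspect the predictions the app saves.
import pandas as pd
import xarray as xr

ds = xr.open_zarr("/tmp/predictions/latest.zarr")  # hypothetical path
# Dim order matches the DataArray constructed in step 5
assert ds.sat_pred.dims == ("init_time", "variable", "step", "y_geostationary", "x_geostationary")
# 12 steps, 15 minutes apart, out to 3 hours ahead
assert (ds.step.values == pd.timedelta_range("15min", "180min", freq="15min").values).all()
```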
diff --git a/src/cloudcasting_app/data.py b/src/cloudcasting_app/data.py
index 0164319..7f6a4c6 100644
--- a/src/cloudcasting_app/data.py
+++ b/src/cloudcasting_app/data.py
@@ -1,14 +1,15 @@
-import numpy as np
-import pandas as pd
-import xarray as xr
-import torch
 import logging
 import os
+import shutil
+import zipfile
+
 import fsspec
-import ocf_blosc2
+import numpy as np
+import pandas as pd
+import torch
+import xarray as xr
 from ocf_data_sampler.select.geospatial import lon_lat_to_geostationary_area_coords
-
 
 logger = logging.getLogger(__name__)
 
 sat_5_path = "sat_5_min.zarr"
@@ -23,17 +24,17 @@
 channel_order = [
-    'IR_016',
-    'IR_039',
-    'IR_087',
-    'IR_097',
-    'IR_108',
-    'IR_120',
-    'IR_134',
-    'VIS006',
-    'VIS008',
-    'WV_062',
-    'WV_073',
+    "IR_016",
+    "IR_039",
+    "IR_087",
+    "IR_097",
+    "IR_108",
+    "IR_120",
+    "IR_134",
+    "VIS006",
+    "VIS008",
+    "WV_062",
+    "WV_073",
 ]
@@ -43,50 +44,51 @@ def crop_input_area(ds):
         [lat_min, lat_max],
         ds.data,
     )
-
+
     ds = ds.isel(x_geostationary=slice(None, None, -1)) # x-axis is in decreasing order
     ds = ds.sel(
-        x_geostationary=slice(x_min, None), 
+        x_geostationary=slice(x_min, None),
         y_geostationary=slice(y_min, None),
     ).isel(
         x_geostationary=slice(0,614),
         y_geostationary=slice(0,372),
     )
-
+
     ds = ds.isel(x_geostationary=slice(None, None, -1)) # flip back
 
     assert len(ds.x_geostationary)==614
     assert len(ds.y_geostationary)==372
-
+
     return ds
 
 
 def prepare_satellite_data(t0: pd.Timestamp):
-
+
     # Download the 5 and/or 15 minutely satellite data
     download_all_sat_data()
-
+
     # Select between the 5/15 minute satellite data sources
     combine_5_and_15_sat_data()
-
+
     # Check the required expected timestamps are available
     check_required_timestamps_available(t0)
-
+
     # Load the data for more preprocessing
     ds = xr.open_zarr(sat_path)
-
+
     # Crop the input area to expected
     ds = crop_input_area(ds)
-
+
     # Reorder channels
     ds = ds.sel(variable=channel_order)
-
+
     # Scale the satellite data from 0-1
     ds = ds / 1023
-
+
     # Resave
     ds = ds.compute()
-    os.system(f"rm -rf {sat_path}")
+    if os.path.exists(sat_path):
+        shutil.rmtree(sat_path)
     ds.to_zarr(sat_path)
@@ -96,9 +98,11 @@ def download_all_sat_data() -> bool:
     Returns:
         bool: Whether the download was successful
     """
-    # Clean out old files
-    os.system(f"rm -r {sat_path} {sat_5_path} {sat_15_path}")
+    logger.debug("Cleaning out old satellite data")
+    for loc in [sat_path, sat_5_path, sat_15_path]:
+        if os.path.exists(loc):
+            shutil.rmtree(loc)
 
     # Set variable to track whether the satellite download is successful
     sat_available = False
@@ -108,42 +112,44 @@ def download_all_sat_data() -> bool:
     fs, _ = fsspec.core.url_to_fs(sat_5_dl_path)
     if fs.exists(sat_5_dl_path):
         sat_available = True
-        logger.info(f"Downloading 5-minute satellite data")
+        logger.info("Downloading 5-minute satellite data")
         fs.get(sat_5_dl_path, "sat_5_min.zarr.zip")
-        os.system(f"unzip -qq sat_5_min.zarr.zip -d {sat_5_path}")
-        os.system(f"rm sat_5_min.zarr.zip")
+        with zipfile.ZipFile("sat_5_min.zarr.zip", "r") as zip_ref:
+            zip_ref.extractall(sat_5_path)
+        os.remove("sat_5_min.zarr.zip")
     else:
-        logger.info(f"No 5-minute data available")
+        logger.info("No 5-minute data available")
 
     # Also download 15-minute satellite if it exists
-    sat_15_dl_path = os.environ["SATELLITE_ZARR_PATH"].replace(".zarr", "_15.zarr")
+    sat_15_dl_path = sat_5_dl_path.replace(".zarr", "_15.zarr")
     if fs.exists(sat_15_dl_path):
         sat_available = True
-        logger.info(f"Downloading 15-minute satellite data")
+        logger.info("Downloading 15-minute satellite data")
         fs.get(sat_15_dl_path, "sat_15_min.zarr.zip")
-        os.system(f"unzip -qq sat_15_min.zarr.zip -d {sat_15_path}")
-        os.system(f"rm sat_15_min.zarr.zip")
+        with zipfile.ZipFile("sat_15_min.zarr.zip", "r") as zip_ref:
+            zip_ref.extractall(sat_15_path)
+        os.remove("sat_15_min.zarr.zip")
     else:
-        logger.info(f"No 15-minute data available")
+        logger.info("No 15-minute data available")
 
     return sat_available
 
 
 def check_required_timestamps_available(t0: pd.Timestamp):
     available_timestamps = get_satellite_timestamps(sat_path)
-
+
     # Need 12 timestamps of 15 minutely data up to and including time t0
     expected_timestamps = pd.date_range(t0-pd.Timedelta("165min"), t0, freq="15min")
-
+
     timestamps_available = np.isin(expected_timestamps, available_timestamps)
-
+
     if not timestamps_available.all():
         missing_timestamps = expected_timestamps[~timestamps_available]
         raise Exception(
             "Some required timestamps missing\n"
             f"Required timestamps: {expected_timestamps}\n"
-            f"Available timestamps: {timestamps_available}\n"
-            f"Missing timestamps: {missing_timestamps}"
+            f"Available timestamps: {available_timestamps}\n"
+            f"Missing timestamps: {missing_timestamps}",
         )
@@ -162,7 +168,6 @@ def get_satellite_timestamps(sat_zarr_path: str) -> pd.DatetimeIndex:
 
 def combine_5_and_15_sat_data() -> None:
     """Select and/or combine the 5 and 15-minutely satellite data and move it to the expected path"""
-
     # Check which satellite data exists
     exists_5_minute = os.path.exists(sat_5_path)
     exists_15_minute = os.path.exists(sat_15_path)
@@ -175,7 +180,7 @@ def combine_5_and_15_sat_data() -> None:
         datetimes_5min = get_satellite_timestamps(sat_5_path)
         logger.info(
             f"Latest 5-minute timestamp is {datetimes_5min.max()}. "
-            f"All the datetimes are: \n{datetimes_5min}"
+            f"All the datetimes are: \n{datetimes_5min}",
         )
     else:
         logger.info("No 5-minute data was found.")
@@ -184,7 +189,7 @@ def combine_5_and_15_sat_data() -> None:
         datetimes_15min = get_satellite_timestamps(sat_15_path)
         logger.info(
-            f"Latest 5-minute timestamp is {datetimes_15min.max()}. "
-            f"All the datetimes are: \n{datetimes_15min}"
+            f"Latest 15-minute timestamp is {datetimes_15min.max()}. "
+            f"All the datetimes are: \n{datetimes_15min}",
         )
     else:
         logger.info("No 15-minute data was found.")
@@ -198,26 +203,26 @@ def combine_5_and_15_sat_data() -> None:
 
     # Move the selected data to the expected path
     if use_5_minute:
-        logger.info(f"Using 5-minutely data.")
-        os.system(f"mv {sat_5_path} {sat_path}")
+        logger.info("Using 5-minutely data.")
+        shutil.move(sat_5_path, sat_path)
     else:
-        logger.info(f"Using 15-minutely data.")
-        os.system(f"mv {sat_15_path} {sat_path}")
-
+        logger.info("Using 15-minutely data.")
+        shutil.move(sat_15_path, sat_path)
+
 
 def get_input_data(ds: xr.Dataset, t0: pd.Timestamp):
-
+
     # Slice the data
     required_timestamps = pd.date_range(t0-pd.Timedelta("165min"), t0, freq="15min")
     ds_sel = ds.reindex(time=required_timestamps)
 
     # Load the data
     ds_sel = ds_sel.compute(scheduler="single-threaded")
-
+
     # Convert to arrays
     X = ds_sel.data.values.astype(np.float32)
 
     # Convert NaNs to -1
     X = np.nan_to_num(X, nan=-1)
 
-    return torch.Tensor(X)
\ No newline at end of file
+    return torch.Tensor(X)
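Both `check_required_timestamps_available` and `get_input_data` rely on the same input window. A one-off sketch of the arithmetic, using a hypothetical `t0`:

```python
# Illustrative only: the model consumes 12 frames at 15-minute spacing,
# ending at (and including) the forecast time t0.
import pandas as pd

t0 = pd.Timestamp("2025-01-01 12:00")  # hypothetical forecast time
window = pd.date_range(t0 - pd.Timedelta("165min"), t0, freq="15min")
assert len(window) == 12 and window[-1] == t0
```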
diff --git a/tests/conftest.py b/tests/conftest.py
index a2f71d3..40b65bb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,11 +1,10 @@
 import os
 
-import pytest
+import fsspec
 import numpy as np
 import pandas as pd
+import pytest
 import xarray as xr
-import fsspec
-
 
 xr.set_options(keep_attrs=True)
@@ -15,16 +14,16 @@ def test_t0():
 
 
 def make_sat_data(test_t0, freq_mins):
-
+
     # Load dataset which only contains coordinates, but no data
     shell_path = f"{os.path.dirname(os.path.abspath(__file__))}/test_data/non_hrv_shell.zarr.zip"
-
+
     ds = xr.open_zarr(fsspec.get_mapper(f"zip::{shell_path}"))
 
-    # Remove original time dim 
+    # Remove original time dim
     ds = ds.drop_vars("time")
 
-    # Add new times so they lead up to present 
+    # Add new times so they lead up to present
     times = pd.date_range(
         test_t0 - pd.Timedelta("3h"),
         test_t0,
@@ -37,7 +36,7 @@ def make_sat_data(test_t0, freq_mins):
         np.zeros([len(ds[c]) for c in ds.xindexes]),
         coords=[ds[c] for c in ds.xindexes],
     )
-
+
     # Add stored attributes to DataArray
     ds.data.attrs = ds.attrs["_data_attrs"]
     del ds.attrs["_data_attrs"]
@@ -47,4 +46,4 @@ def make_sat_data(test_t0, freq_mins):
 
 @pytest.fixture()
 def sat_5_data(test_t0):
-    return make_sat_data(test_t0, freq_mins=5)
\ No newline at end of file
+    return make_sat_data(test_t0, freq_mins=5)
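The fixture reads a zipped zarr "shell" (coordinates only, no data) through fsspec's chained zip protocol. A standalone sketch of that open call, mirroring `make_sat_data` above:

```python
# Illustrative only: how the test shell store is opened via fsspec's
# "zip::" chained protocol.
import fsspec
import xarray as xr

shell_path = "tests/test_data/non_hrv_shell.zarr.zip"
ds = xr.open_zarr(fsspec.get_mapper(f"zip::{shell_path}"))
```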
" - f"All the datetimes are: \n{datetimes_15min}" + f"All the datetimes are: \n{datetimes_15min}", ) else: logger.info("No 15-minute data was found.") @@ -198,26 +203,26 @@ def combine_5_and_15_sat_data() -> None: # Move the selected data to the expected path if use_5_minute: - logger.info(f"Using 5-minutely data.") + logger.info("Using 5-minutely data.") os.system(f"mv {sat_5_path} {sat_path}") else: - logger.info(f"Using 15-minutely data.") + logger.info("Using 15-minutely data.") os.system(f"mv {sat_15_path} {sat_path}") - + def get_input_data(ds: xr.Dataset, t0: pd.Timestamp): - + # Slice the data required_timestamps = pd.date_range(t0-pd.Timedelta("165min"), t0, freq="15min") ds_sel = ds.reindex(time=required_timestamps) # Load the data ds_sel = ds_sel.compute(scheduler="single-threaded") - + # Convert to arrays X = ds_sel.data.values.astype(np.float32) # Convert NaNs to -1 X = np.nan_to_num(X, nan=-1) - return torch.Tensor(X) \ No newline at end of file + return torch.Tensor(X) diff --git a/tests/conftest.py b/tests/conftest.py index a2f71d3..40b65bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,10 @@ import os -import pytest +import fsspec import numpy as np import pandas as pd +import pytest import xarray as xr -import fsspec - xr.set_options(keep_attrs=True) @@ -15,16 +14,16 @@ def test_t0(): def make_sat_data(test_t0, freq_mins): - + # Load dataset which only contains coordinates, but no data shell_path = f"{os.path.dirname(os.path.abspath(__file__))}/test_data/non_hrv_shell.zarr.zip" - + ds = xr.open_zarr(fsspec.get_mapper(f"zip::{shell_path}")) - # Remove original time dim + # Remove original time dim ds = ds.drop_vars("time") - # Add new times so they lead up to present + # Add new times so they lead up to present times = pd.date_range( test_t0 - pd.Timedelta("3h"), test_t0, @@ -37,7 +36,7 @@ def make_sat_data(test_t0, freq_mins): np.zeros([len(ds[c]) for c in ds.xindexes]), coords=[ds[c] for c in ds.xindexes], ) - + # Add stored attributes to DataArray ds.data.attrs = ds.attrs["_data_attrs"] del ds.attrs["_data_attrs"] @@ -47,4 +46,4 @@ def make_sat_data(test_t0, freq_mins): @pytest.fixture() def sat_5_data(test_t0): - return make_sat_data(test_t0, freq_mins=5) \ No newline at end of file + return make_sat_data(test_t0, freq_mins=5) diff --git a/tests/test_app.py b/tests/test_app.py index 5fad7ff..c763ed7 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,8 +1,8 @@ -import zarr import os -import xarray as xr -import numpy as np +import numpy as np +import xarray as xr +import zarr from cloudcasting_app.app import app @@ -17,30 +17,30 @@ def test_app(sat_5_data, tmp_path, test_t0): os.environ["OUTPUT_PREDICTION_DIRECTORY"] = f"{tmp_path}" with zarr.storage.ZipStore("temp_sat.zarr.zip", mode="x") as store: - sat_5_data.to_zarr(store) + sat_5_data.to_zarr(store) app() - + # Check the two output files have been created latest_zarr_path = f"{tmp_path}/latest.zarr" t0_string_zarr_path = test_t0.strftime(f"{tmp_path}/%Y-%m-%dT%H:%M.zarr") assert os.path.exists(latest_zarr_path) assert os.path.exists(t0_string_zarr_path) - + # Load the predictions and check them ds_y_hat = xr.open_zarr(latest_zarr_path) - + assert "sat_pred" in ds_y_hat assert ( list(ds_y_hat.sat_pred.coords)== ["init_time", "step", "variable", "x_geostationary", "y_geostationary"] ) - + # Make sure all the coords are correct assert ds_y_hat.init_time == test_t0 assert len(ds_y_hat.step)==12 assert (ds_y_hat.x_geostationary==sat_5_data.x_geostationary).all() assert 
(ds_y_hat.y_geostationary==sat_5_data.y_geostationary).all() - + # Make sure all of the predictions are finite - assert np.isfinite(ds_y_hat.sat_pred).all() \ No newline at end of file + assert np.isfinite(ds_y_hat.sat_pred).all()
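The test drives the app end-to-end; outside pytest, the two environment variables named in the module docstring are all that is needed. A hedged sketch, with both paths as hypothetical placeholders:

```python
# Illustrative only: exercising the app outside pytest.
import os

os.environ["SATELLITE_ZARR_PATH"] = "temp_sat.zarr.zip"        # hypothetical input
os.environ["OUTPUT_PREDICTION_DIRECTORY"] = "/tmp/predictions"  # hypothetical output

from cloudcasting_app.app import app

app()  # t0 defaults to the most recent half-hour boundary
```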