Setup ci/cd linting #3

Merged: 7 commits, May 1, 2024
6 changes: 6 additions & 0 deletions .flake8
@@ -0,0 +1,6 @@
[flake8]
max-line-length = 88
extend-ignore = E203
select = C,E,F,W,B,B950
ignore = E203, E501, W503
per-file-ignores = __init__.py:F401
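
The `ignore` entries above appear to follow the flake8 settings that Black's documentation suggests for compatibility: E203 and W503 conflict with Black's output, and E501 is presumably superseded here by B950 from flake8-bugbear. A minimal sketch of the two formatting patterns involved (the variable names are purely illustrative and the expressions are kept short for readability):

```python
values = list(range(10))
lower, upper, offset = 1, 4, 2

# Black treats ":" in a slice like a binary operator when the bounds are
# expressions, so it puts a space on both sides. pycodestyle reports that
# as E203 (whitespace before ':'), hence E203 in the ignore list.
black_slice = values[lower + offset : upper + offset]

# When Black has to split an expression it breaks *before* binary operators.
# pycodestyle reports that as W503, hence W503 in the ignore list.
total = (
    lower
    + upper
    + offset
)

print(black_slice, total)
```

Both statements are valid Python; the ignored codes only concern style checks that would otherwise disagree with the formatter.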
16 changes: 16 additions & 0 deletions .github/workflows/pre-commit.yml
@@ -0,0 +1,16 @@
name: linting

on:
  push:
    branches: "*"
  pull_request:
    branches: "*"

jobs:
  linting:
    name: "pre-commit hooks"
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v2
      - uses: pre-commit/[email protected]
21 changes: 21 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,21 @@
# https://pre-commit.com/
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.0.1
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
  # isort should run before black as black sometimes tweaks the isort output
  - repo: https://github.com/PyCQA/isort
    rev: 5.12.0
    hooks:
      - id: isort
  # https://github.com/python/black#version-control-integration
  - repo: https://github.com/psf/black
    rev: 22.3.0
    hooks:
      - id: black
  - repo: https://github.com/PyCQA/flake8
    rev: 6.1.0
    hooks:
      - id: flake8
16 changes: 9 additions & 7 deletions README.md
@@ -1,5 +1,7 @@
# weather-model-graphs

+[![linting](https://github.com/mllam/weather-model-graphs/actions/workflows/pre-commit.yml/badge.svg)](https://github.com/mllam/weather-model-graphs/actions/workflows/pre-commit.yml)
+
`weather-model-graphs` is a package for creating, visualising and storing message-passing graphs for data-driven weather models.

The package is designed to use `networkx.DiGraph` objects as the primary data structure for the graph representation right until the graph is to be stored on disk into a specific format.
@@ -29,7 +31,7 @@ pdm install --group pytorch

# Usage

-The best way to understand how to use `weather-model-graphs` is to look at the [notebooks/constructing_the_graph.ipynb](notebooks/constructing_the_graph.ipynb) notebook, to have look at the tests in [tests/](tests/) or simply to read through the source code. 
+The best way to understand how to use `weather-model-graphs` is to look at the [notebooks/constructing_the_graph.ipynb](notebooks/constructing_the_graph.ipynb) notebook, to have look at the tests in [tests/](tests/) or simply to read through the source code.
In addition you can read the [background and design](#background-and-design) section below to understand the design principles of `weather-model-graphs`.

## Example, Keisler 2021 flat graph architecture
@@ -101,9 +103,9 @@ The graph generation in `weather-model-graphs` is split into to the following st
- **networkx** `.pickle` file: save `networkx.DiGraph` objects using `pickle` to disk (`weather_model_graphs.save.to_pickle(...)`)

- [pytorch-geometric](https://github.com/pyg-team/pytorch_geometric) for [neural-lam](https://github.com/mllam/neural-lam): edges indexes and features are stored in separate `torch.Tensor` objects serialised to disk that can then be loaded into `torch_geometric.data.Data` objects (`weather_model_graphs.save.to_pyg(...)`

### Diagram of the graph generation process:

Below visualises the graph generation process in `weather-model-graphs` for the example given above:

```mermaid
@@ -186,9 +188,9 @@ The code layout of `weather-model-graphs` is organise into submodules by the fun
 ```
 weather_model_graphs
     .create
-        .archetype: 
-            for creating specific archetype graph 
-            architectures (e.g. Keisler 2021, Lam et al 2023, 
+        .archetype:
+            for creating specific archetype graph
+            architectures (e.g. Keisler 2021, Lam et al 2023,
             Oscarsson et al 2023)
         .base
             general interface for creating graph architectures
@@ -198,7 +200,7 @@ weather_model_graphs
         .grid
             for creating the grid nodes
     .visualise
-        for plotting graphs, allowing for easy visualisation using any 
+        for plotting graphs, allowing for easy visualisation using any
         edge or node attribute for colouring
     .save
         for saving the graph to specific formats (e.g. pytorch-geometric)
4 changes: 3 additions & 1 deletion notebooks/constructing_the_graph.ipynb
@@ -512,7 +512,9 @@
    ],
    "source": [
     "m2m_graph = graph_components.pop(\"m2m\")\n",
-    "m2m_graph_components = wmg.split_graph_by_edge_attribute(graph=m2m_graph, attribute=\"direction\")\n",
+    "m2m_graph_components = wmg.split_graph_by_edge_attribute(\n",
+    "    graph=m2m_graph, attribute=\"direction\"\n",
+    ")\n",
     "m2m_graph_components = {\n",
     "    f\"m2m_{name}\": graph for name, graph in m2m_graph_components.items()\n",
     "}\n",
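
The notebook cell in this hunk splits the mesh-to-mesh graph by the `direction` edge attribute and then prefixes each resulting name with `m2m_`. The sketch below mimics that pattern with plain tuples and dictionaries so it runs without `networkx`. The helper `split_edges_by_attribute` is a hypothetical stand-in written for illustration; it is not the implementation of `wmg.split_graph_by_edge_attribute`, which operates on `networkx.DiGraph` objects.

```python
from collections import defaultdict


def split_edges_by_attribute(edges, attribute):
    """Group (source, target, attrs) edge tuples by the value of one attribute.

    Hypothetical helper for illustration only, not the library function.
    """
    groups = defaultdict(list)
    for source, target, attrs in edges:
        groups[attrs[attribute]].append((source, target, attrs))
    return dict(groups)


# A toy edge list carrying the "direction" attribute ("up", "down" or "same").
edges = [
    ("a", "b", {"direction": "up"}),
    ("b", "a", {"direction": "down"}),
    ("a", "c", {"direction": "same"}),
    ("c", "b", {"direction": "up"}),
]

components = split_edges_by_attribute(edges, attribute="direction")

# Same renaming step as in the notebook cell: prefix every key with "m2m_".
m2m_components = {f"m2m_{name}": group for name, group in components.items()}

for name, group in sorted(m2m_components.items()):
    print(name, len(group))
```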
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -28,6 +28,8 @@ visualisation = [
 requires = ["pdm-backend"]
 build-backend = "pdm.backend"

+[tool.isort]
+profile = "black"

 [tool.pdm]
 distribution = true
5 changes: 4 additions & 1 deletion src/weather_model_graphs/__init__.py
@@ -1,2 +1,5 @@
 from . import create, visualise
-from .networkx_utils import split_graph_by_edge_attribute, replace_node_labels_with_unique_ids
+from .networkx_utils import (
+    replace_node_labels_with_unique_ids,
+    split_graph_by_edge_attribute,
+)
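
The rewritten import in this hunk shows the two visible effects of running isort with `profile = "black"`: imported names are sorted, and a statement longer than 88 characters is wrapped into a parenthesised block with one name per line and a trailing comma. The function below is a rough imitation of that behaviour for a single `from ... import ...` statement. The name `wrap_from_import` is hypothetical, and isort's real algorithm handles many more cases (comments, aliases, ordering by type, and so on).

```python
def wrap_from_import(module, names, max_line_length=88, indent="    "):
    """Format one `from module import names` statement, black-profile style.

    Illustrative approximation only; not isort's actual implementation.
    """
    ordered = sorted(names)
    one_line = f"from {module} import {', '.join(ordered)}"
    if len(one_line) <= max_line_length:
        return one_line
    # Too long: one name per line, trailing comma, closing parenthesis alone.
    body = "".join(f"{indent}{name},\n" for name in ordered)
    return f"from {module} import (\n{body})"


short = wrap_from_import(".", ["visualise", "create"])
long = wrap_from_import(
    ".networkx_utils",
    ["split_graph_by_edge_attribute", "replace_node_labels_with_unique_ids"],
)
print(short)
print(long)
```

The second call reproduces the wrapped form added in this file, while the first stays on one line because it fits within the limit.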
2 changes: 1 addition & 1 deletion src/weather_model_graphs/create/__init__.py
@@ -1,2 +1,2 @@
+from . import archetype
 from .base import create_all_graph_components
-from . import archetype
26 changes: 16 additions & 10 deletions src/weather_model_graphs/create/archetype.py
@@ -33,10 +33,13 @@ def create_keisler_graph(xy_grid, merge_components=True):
         g2m_connectivity="nearest_neighbours",
         g2m_connectivity_kwargs=dict(
             max_num_neighbours=4,
-        )
+        ),
     )

-def create_graphcast_graph(xy_grid, refinement_factor=3, max_num_levels=None, merge_components=True):
+
+def create_graphcast_graph(
+    xy_grid, refinement_factor=3, max_num_levels=None, merge_components=True
+):
     """
     Create a graph following the Lam et al (2023, https://arxiv.org/abs/2212.12794) GraphCast architecture.

@@ -46,7 +49,7 @@ def create_graphcast_graph(xy_grid, refinement_factor=3, max_num_levels=None, me
     to its nearest 4 grid points. The mesh to grid connectivity connects each grid point to the nearest mesh node.

     TODO: Verify that GraphCast does in fact use these g2m and m2g connectivities.
-    
+
     Parameters
     ----------
     xy_grid: np.ndarray
@@ -69,14 +72,17 @@ def create_graphcast_graph(xy_grid, refinement_factor=3, max_num_levels=None, me
         xy=xy_grid,
         merge_components=merge_components,
         m2m_connectivity="flat_multiscale",
-        m2m_connectivity_kwargs=dict(refinement_factor=refinement_factor, max_num_levels=max_num_levels),
+        m2m_connectivity_kwargs=dict(
+            refinement_factor=refinement_factor, max_num_levels=max_num_levels
+        ),
         m2g_connectivity="nearest_neighbour",
         g2m_connectivity="nearest_neighbours",
         g2m_connectivity_kwargs=dict(
             max_num_neighbours=4,
-        )
+        ),
     )

+
 def create_oscarsson_hierarchical_graph(xy_grid, merge_components=True):
     """
     Create a graph following Oscarsson et al (2023, https://arxiv.org/abs/2309.17370)
@@ -89,20 +95,20 @@ def create_oscarsson_hierarchical_graph(xy_grid, merge_components=True):
     edge connections each edge has a `direction` attribute (with value "up",
     "down", or "same"). In addition the `level` attribute indicates which two levels
     are connected for cross-level edges (e.g. "1>2" for edges between level 1 and 2).
-    
+
     The grid to mesh connectivity connects each mesh node to the four nearest
     grid points, and the mesh to grid connectivity connects each grid point to
     the nearest mesh node.
-    
+
     TODO: Is this the right connectivity for the g2m and m2g components?
-    
+
     Parameters
     ----------
     xy_grid: np.ndarray
         2D array of grid point positions.
     merge_components: bool
         Whether to merge the components of the graph.
-    
+
     Returns
     -------
     networkx.DiGraph or dict[networkx.DiGraph]
@@ -117,5 +123,5 @@ def create_oscarsson_hierarchical_graph(xy_grid, merge_components=True):
         g2m_connectivity="nearest_neighbours",
         g2m_connectivity_kwargs=dict(
             max_num_neighbours=4,
-        )
+        ),
     )
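
The docstrings in this file describe two connectivity rules: grid-to-mesh (g2m) links each mesh node to its four nearest grid points (`max_num_neighbours=4`), and mesh-to-grid (m2g) links each grid point to its single nearest mesh node. The sketch below works through both rules on a toy 4x4 grid with a brute-force distance search. The helper `nearest_neighbours`, the coordinates, and the `(source, target)` edge-tuple convention are all assumptions made for illustration; the package's own search method and edge representation may differ.

```python
import math


def nearest_neighbours(candidates, point, max_num_neighbours):
    """Return indices of the candidates closest to `point` (brute force).

    Hypothetical helper for illustration; ties keep the earlier candidate.
    """
    order = sorted(
        range(len(candidates)), key=lambda i: math.dist(candidates[i], point)
    )
    return order[:max_num_neighbours]


# Toy data: a regular 4x4 grid and two mesh nodes.
grid = [(x, y) for x in range(4) for y in range(4)]
mesh = [(0.5, 0.5), (2.5, 2.5)]

# g2m: every mesh node is connected to its 4 nearest grid points.
g2m_edges = [
    (grid[i], mesh_node)
    for mesh_node in mesh
    for i in nearest_neighbours(grid, mesh_node, max_num_neighbours=4)
]

# m2g: every grid point is connected to its single nearest mesh node.
m2g_edges = [
    (mesh[nearest_neighbours(mesh, grid_point, max_num_neighbours=1)[0]], grid_point)
    for grid_point in grid
]

print(len(g2m_edges), "g2m edges,", len(m2g_edges), "m2g edges")
```

With these rules the number of g2m edges scales with the number of mesh nodes (four per node), while m2g produces exactly one edge per grid point.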