Skip to content

Commit

Permalink
Updates for working output to neural-lam (#6)
Browse files Browse the repository at this point in the history
* adapt pyg writing for neural-lam

* fix missing and circular imports

* make regular input grid assumption explicit

* handle missing edge attributes in splitting

* minor fix for neural-lam save
  • Loading branch information
leifdenby authored May 6, 2024
1 parent 895fd70 commit 940609b
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/weather_model_graphs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import create, visualise
from . import create, save, visualise
from .networkx_utils import (
replace_node_labels_with_unique_ids,
split_graph_by_edge_attribute,
Expand Down
2 changes: 1 addition & 1 deletion src/weather_model_graphs/create/archetype.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import create_all_graph_components
from .base import create_all_graph_components


def create_keisler_graph(xy_grid, merge_components=True):
Expand Down
6 changes: 6 additions & 0 deletions src/weather_model_graphs/create/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ def create_all_graph_components(
"""
graph_components: dict[networkx.DiGraph] = {}

if len(xy.shape) != 3:
raise NotImplementedError(
"Mesh coordinates are assumed to lie on a regular grid so that "
"the coordinates values are given with an array of shape [2, nx, ny]"
)

if m2m_connectivity == "flat":
logger.warning(
"Using refinement factor 2 between grid and mesh nodes for flat mesh graph"
Expand Down
6 changes: 6 additions & 0 deletions src/weather_model_graphs/create/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ def create_grid_graph_nodes(xy, level_id=-1):
networkx.Graph
Graph representing the grid nodes
"""
if len(xy.shape) != 3:
raise NotImplementedError(
"Mesh coordinates are assumed to lie on a regular grid so that "
"the coordinates values are given with an array of shape [2, nx, ny]"
)

# grid nodes
Ny, Nx = xy.shape[1:]

Expand Down
21 changes: 21 additions & 0 deletions src/weather_model_graphs/networkx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def sort_nodes_internally(nx_graph, node_attribute=None, edge_attribute=None):
return H


class MissingEdgeAttributeError(Exception):
pass


def split_graph_by_edge_attribute(graph, attribute):
"""
Split a graph into subgraphs based on an edge attribute, returning
Expand All @@ -61,6 +65,12 @@ def split_graph_by_edge_attribute(graph, attribute):
Dictionary of subgraphs keyed by edge attribute value
"""

# check if any node has the attribute
if not any(attribute in graph.edges[edge] for edge in graph.edges):
raise MissingEdgeAttributeError(
f"Edge attribute '{attribute}' not found in graph. Check the attribute."
)

# Get unique edge attribute values
edge_values = set(networkx.get_edge_attributes(graph, attribute).values())

Expand All @@ -71,6 +81,17 @@ def split_graph_by_edge_attribute(graph, attribute):
[edge for edge in graph.edges if graph.edges[edge][attribute] == edge_value]
)

# copy node attributes
for subgraph in subgraphs.values():
for node in subgraph.nodes:
subgraph.nodes[node].update(graph.nodes[node])

# check that at least one subgraph was created
if len(subgraphs) == 0:
raise ValueError(
f"No subgraphs were created. Check the edge attribute '{attribute}'."
)

return subgraphs


Expand Down
62 changes: 51 additions & 11 deletions src/weather_model_graphs/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from loguru import logger

from .networkx_utils import split_graph_by_edge_attribute
from .networkx_utils import MissingEdgeAttributeError, split_graph_by_edge_attribute

try:
import torch_geometric.utils.convert as pyg_convert
Expand All @@ -16,10 +16,16 @@


def to_pyg(
graph: networkx.DiGraph, output_directory: str, name: str, list_from_attribute=None
graph: networkx.DiGraph,
output_directory: str,
name: str,
edge_features=["vdiff"],
node_features=["pos"],
list_from_attribute=None,
):
"""
Save the networkx graph to PyTorch Geometric format.
Save the networkx graph to PyTorch Geometric format that matches what the
neural-lam model expects as input
Parameters
----------
Expand All @@ -36,6 +42,10 @@ def to_pyg(
stored edge index and features are then the concatenation of the split graphs,
so that a separate pyg.Data object can be created for each subgraph
(e.g. one for each level in a multi-level graph). Default is None.
edge_features: List[str]
list of edge attributes to include in `{name}_edge_features.pt` file
node_features: List[str]
list of node attributes to include in `{name}_node_features.pt` file
Returns
-------
Expand All @@ -53,33 +63,63 @@ def to_pyg(
if len(set(graph.nodes)) != len(graph.nodes):
raise ValueError("Node labels must be unique.")

# remove all node attributes but the ones we want to keep
for node in graph.nodes:
for attr in list(graph.nodes[node].keys()):
if attr not in node_features:
del graph.nodes[node][attr]

def _get_edge_indecies(pyg_g):
return pyg_g.edge_index

def _get_edge_features(pyg_g):
if edge_features != ["vdiff"]:
raise NotImplementedError(edge_features_values)
# TODO: handle features of different types more generally, i.e. both single ("len") values and tuples (like "vdiff")
return torch.cat((pyg_g.len.unsqueeze(1), pyg_g.vdiff), dim=1).to(torch.float32)

def _get_node_features(pyg_g):
if node_features != ["pos"]:
raise NotImplementedError(node_features_values)
return pyg_g.pos.to(torch.float32)

if list_from_attribute is not None:
# create a list of graph objects by splitting the graph by the list_from_attribute
sub_graphs = split_graph_by_edge_attribute(
graph=graph, attribute=list_from_attribute
)
try:
sub_graphs = list(
split_graph_by_edge_attribute(
graph=graph, attribute=list_from_attribute
).values()
)
except MissingEdgeAttributeError:
# neural-lam still expects a list of graphs, so if the attribute is missing
# we just return the original graph as a list
sub_graphs = [graph]
pyg_graphs = [pyg_convert.from_networkx(g) for g in sub_graphs]
else:
pyg_graphs = [pyg_convert.from_networkx(graph)]

edge_features = [_get_edge_features(pyg_g) for pyg_g in pyg_graphs]
edge_features_values = [_get_edge_features(pyg_g) for pyg_g in pyg_graphs]
edge_indecies = [_get_edge_indecies(pyg_g) for pyg_g in pyg_graphs]
node_features_values = [_get_node_features(pyg_g) for pyg_g in pyg_graphs]

if len(pyg_graphs) == 1:
edge_features = edge_features[0]
if list_from_attribute is None:
edge_features_values = edge_features_values[0]
edge_indecies = edge_indecies[0]

Path(output_directory).mkdir(exist_ok=True, parents=True)
fp_edge_index = Path(output_directory) / f"{name}_edge_index.pt"
fp_features = Path(output_directory) / f"{name}_features.pt"
torch.save(edge_indecies, fp_edge_index)
torch.save(edge_features, fp_features)
logger.info(f"Saved edge index and features to {fp_edge_index} and {fp_features}.")
torch.save(edge_features_values, fp_features)
logger.info(
f"Saved edge index to {fp_edge_index} and features {edge_features} to {fp_features}."
)

# save node features
fp_node_features = Path(output_directory) / f"{name}_node_features.pt"
torch.save(node_features_values, fp_node_features)
logger.info(f"Saved node features {node_features} to {fp_node_features}.")


def to_pickle(graph: networkx.DiGraph, output_directory: str, name: str):
Expand Down

0 comments on commit 940609b

Please sign in to comment.