Skip to content

Commit

Permalink
array converter (#103)
Browse files Browse the repository at this point in the history
More complete draft of an array conversion function to expand sparse dictionary representations to a full array. Also some housekeeping, move things into their own modules.

Conversion function still needs testing, only the barest minimum here.
  • Loading branch information
wpbonelli authored Mar 8, 2025
1 parent 56a0267 commit 2a7af0d
Show file tree
Hide file tree
Showing 21 changed files with 552 additions and 567 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ jobs:
- name: Run ruff
run: pixi run lint

- name: Run mypy
run: pixi run mypy flopy4

build:
name: Build
runs-on: ubuntu-latest
Expand Down
101 changes: 0 additions & 101 deletions flopy4/mf6/__init__.py
Original file line number Diff line number Diff line change
@@ -1,101 +0,0 @@
from abc import ABC
from datetime import datetime
from pathlib import Path
from typing import Optional

import numpy as np
from attrs import define
from numpy.typing import NDArray
from xattree import ROOT, array, dim, field, xattree

__all__ = [
"Component",
"Package",
"Model",
"Simulation",
"Solution",
"Exchange",
"COMPONENTS",
]

COMPONENTS = {}
"""MF6 component registry."""


class Component(ABC):
@classmethod
def __attrs_init_subclass__(cls):
COMPONENTS[cls.__name__.lower()] = cls


@define
class Package(Component):
pass


@define
class Model(Component):
pass


@define
class Solution(Package):
pass


@define
class Exchange(Package):
exgtype: type = field()
exgfile: Path = field()
exgmnamea: Optional[str] = field(default=None)
exgmnameb: Optional[str] = field(default=None)


@xattree
class Tdis(Package):
@define
class PeriodData:
perlen: float
nstp: int
tsmult: float

nper: int = dim(
name="per",
default=1,
scope=ROOT,
metadata={"block": "dimensions"},
)
time_units: Optional[str] = field(
default=None, metadata={"block": "options"}
)
start_date_time: Optional[datetime] = field(
default=None, metadata={"block": "options"}
)
# perioddata: NDArray[np.object_] = array(
# PeriodData,
# dims=("per",),
# metadata={"block": "perioddata"},
# )
perlen: NDArray[np.floating] = array(
default=1.0,
dims=("per",),
metadata={"block": "perioddata"},
)
nstp: NDArray[np.integer] = array(
default=1,
dims=("per",),
metadata={"block": "perioddata"},
)
tsmult: NDArray[np.floating] = array(
default=1.0,
dims=("per",),
metadata={"block": "perioddata"},
)


@xattree
class Simulation(Component):
models: dict[str, Model] = field()
exchanges: dict[str, Exchange] = field()
solutions: dict[str, Solution] = field()
tdis: Tdis = field()
10 changes: 10 additions & 0 deletions flopy4/mf6/component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from abc import ABC

COMPONENTS = {}
"""MF6 component registry."""


class Component(ABC):
@classmethod
def __attrs_init_subclass__(cls):
COMPONENTS[cls.__name__.lower()] = cls
4 changes: 4 additions & 0 deletions flopy4/mf6/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import numpy as np

FILL_DEFAULT = np.nan
FILL_DNODATA = 1e30
66 changes: 66 additions & 0 deletions flopy4/mf6/converters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@

import numpy as np
from numpy.typing import NDArray
from xattree import _get_xatspec

from flopy4.mf6.constants import FILL_DNODATA


def convert_array(value, self_, field) -> NDArray:
if not isinstance(value, dict):
# if not a dict, assume it's a numpy array
# and let xarray deal with it if it isn't
return value

# get spec
spec = _get_xatspec(type(self_))
field = spec.arrays[field.name]
if not field.dims:
raise ValueError(f"Field {field} missing dims")

# resolve dims
explicit_dims = self_.__dict__.get("dims", {})
inherited_dims = self_.parent.data.dims if self_.parent else {}
dims = inherited_dims | explicit_dims
shape = [dims.get(d, d) for d in field.dims]
unresolved = [d for d in shape if isinstance(d, str)]
if any(unresolved):
raise ValueError(f"Couldn't resolve dims: {unresolved}")

# create array
a = np.full(shape, fill_value=FILL_DNODATA, dtype=field.dtype)

def _get_nn(cellid):
match len(cellid):
case 1:
return cellid[0]
case 2:
k, j = cellid
return k * dims["ncpl"] + j
case 3:
k, i, j = cellid
return k * dims["row"] * dims["col"] + i * dims["col"] + j
case _:
raise ValueError(f"Invalid cellid: {cellid}")

# populate array. TODO: is there a way to do this
# without hardcoding awareness of kper and cellid?
if "per" in dims:
for kper, period in value.items():
if kper == "*":
kper = 0
match len(shape):
case 1:
a[kper] = value
case _:
for cellid, v in period.items():
nn = _get_nn(cellid)
a[kper, nn] = v
if kper == "*":
break
else:
for cellid, v in value.items():
nn = _get_nn(cellid)
a[nn] = v

return a
15 changes: 15 additions & 0 deletions flopy4/mf6/exchange.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from pathlib import Path
from typing import Optional

from attrs import define
from xattree import field

from flopy4.mf6.package import Package


@define
class Exchange(Package):
exgtype: type = field()
exgfile: Path = field()
exgmnamea: Optional[str] = field(default=None)
exgmnameb: Optional[str] = field(default=None)
2 changes: 1 addition & 1 deletion flopy4/mf6/gwf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from attrs import define
from xattree import field, xattree

from flopy4.mf6 import Model
from flopy4.mf6.gwf.chd import Chd
from flopy4.mf6.gwf.dis import Dis
from flopy4.mf6.gwf.ic import Ic
from flopy4.mf6.gwf.npf import Npf
from flopy4.mf6.gwf.oc import Oc
from flopy4.mf6.model import Model

__all__ = ["Gwf", "Chd", "Dis", "Ic", "Npf", "Oc"]

Expand Down
50 changes: 13 additions & 37 deletions flopy4/mf6/gwf/chd.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,16 @@
from pathlib import Path
from typing import Optional

import attrs
import numpy as np
from attrs import define
from attrs import Converter, define
from numpy.typing import NDArray
from xattree import _get_xatspec, array, field, xattree
from xattree import array, field, xattree

from flopy4.mf6 import Package
from flopy4.mf6.converters import convert_array
from flopy4.mf6.package import Package

dnodata = 1e30


def _get_nn(ncol, nrow, k, i, j):
return k * nrow * ncol + i * ncol + j


def _convert_array(value, self_, field):
if not isinstance(value, dict):
return value

inherited_dims = self_.__dict__.get("dims", {})
spec = _get_xatspec(type(self_))
field = spec.arrays["head"]
shape = field.dims
if not shape:
raise ValueError()
dims = [inherited_dims.get(d, d) for d in shape]
# TODO pull out dtype from annotation
a = np.full(dims, fill_value=dnodata, dtype=np.float64)
for kper, period in value.items():
if kper == "*":
kper = 0
for cellid, v in period.items():
nn = _get_nn(inherited_dims["col"], inherited_dims["row"], *cellid)
a[kper, nn] = v
return a


@xattree(multi="list")
@xattree
class Chd(Package):
@define(slots=False)
class Steps:
Expand Down Expand Up @@ -75,9 +47,7 @@ class Steps:
),
default=None,
metadata={"block": "period"},
converter=attrs.Converter(
_convert_array, takes_self=True, takes_field=True
),
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
aux: Optional[NDArray[np.floating]] = array(
dims=(
Expand All @@ -86,6 +56,7 @@ class Steps:
),
default=None,
metadata={"block": "period"},
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
boundname: Optional[NDArray[np.str_]] = array(
dims=(
Expand All @@ -94,7 +65,12 @@ class Steps:
),
default=None,
metadata={"block": "period"},
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
steps: Optional[NDArray[np.object_]] = array(
Steps, dims=("per", "node"), default=None, metadata={"block": "period"}
Steps,
dims=("per", "node"),
default=None,
metadata={"block": "period"},
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
9 changes: 8 additions & 1 deletion flopy4/mf6/gwf/dis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
from attrs import Converter
from numpy.typing import NDArray
from xattree import array, dim, field, xattree

from flopy4.mf6 import Package
from flopy4.mf6.converters import convert_array
from flopy4.mf6.package import Package


@xattree
Expand Down Expand Up @@ -46,26 +48,31 @@ class Dis(Package):
dims=("col",),
default=1.0,
metadata={"block": "griddata"},
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
delc: NDArray[np.floating] = array(
dims=("row",),
default=1.0,
metadata={"block": "griddata"},
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
top: NDArray[np.floating] = array(
dims=("col", "row"),
default=1.0,
metadata={"block": "griddata"},
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
botm: NDArray[np.floating] = array(
dims=("col", "row", "lay"),
default=0.0,
metadata={"block": "griddata"},
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
idomain: NDArray[np.integer] = array(
dims=("col", "row", "lay"),
default=1,
metadata={"block": "griddata"},
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
nnodes: int = dim(name="node", scope="gwf", init=False)

Expand Down
5 changes: 4 additions & 1 deletion flopy4/mf6/gwf/ic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
from attrs import Converter
from numpy.typing import NDArray
from xattree import array, field, xattree

from flopy4.mf6 import Package
from flopy4.mf6.converters import convert_array
from flopy4.mf6.package import Package


@xattree
Expand All @@ -11,6 +13,7 @@ class Ic(Package):
dims=("node",),
default=1.0,
metadata={"block": "packagedata"},
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
export_array_ascii: bool = field(
default=False, metadata={"block": "options"}
Expand Down
Loading

0 comments on commit 2a7af0d

Please sign in to comment.