Skip to content

Commit

Permalink
add basic MFArray tests
Browse files Browse the repository at this point in the history
Co-authored-by: mjreno <[email protected]>
  • Loading branch information
2 people authored and wpbonelli committed Jul 20, 2024
1 parent c5a5542 commit e500ac4
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 67 deletions.
12 changes: 6 additions & 6 deletions docs/examples/array_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
# Open and load a NumPy array representation

fhandle = open(internal)
imfa = MFArray.load(fhandle, data_path, shape)
imfa = MFArray.load(fhandle, data_path, shape, header=False)

# Get values

Expand All @@ -70,7 +70,7 @@
plt.colorbar()

fhandle = open(constant)
cmfa = MFArray.load(fhandle, data_path, shape)
cmfa = MFArray.load(fhandle, data_path, shape, header=False)
cvals = cmfa.value
plt.imshow(cvals[0:100])
plt.colorbar()
Expand All @@ -93,7 +93,7 @@
# External

fhandle = open(external)
emfa = MFArray.load(fhandle, data_path, shape)
emfa = MFArray.load(fhandle, data_path, shape, header=False)
evals = emfa.value
evals

Expand All @@ -118,7 +118,7 @@

fhandle = open(ilayered)
shape = (3, 1000, 100)
ilmfa = MFArray.load(fhandle, data_path, shape, layered=True)
ilmfa = MFArray.load(fhandle, data_path, shape, header=False, layered=True)
vals = ilmfa.value

ilmfa._value # internal storage
Expand Down Expand Up @@ -165,7 +165,7 @@

fhandle = open(clayered)
shape = (3, 1000, 100)
clmfa = MFArray.load(fhandle, data_path, shape, layered=True)
clmfa = MFArray.load(fhandle, data_path, shape, header=False, layered=True)

clmfa._value

Expand Down Expand Up @@ -218,7 +218,7 @@

fhandle = open(mlayered)
shape = (3, 1000, 100)
mlmfa = MFArray.load(fhandle, data_path, shape, layered=True)
mlmfa = MFArray.load(fhandle, data_path, shape, header=False, layered=True)

mlmfa.how

Expand Down
61 changes: 25 additions & 36 deletions flopy4/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ class MFArray(MFParameter, NumPyArrayMixin):

def __init__(
self,
array,
shape,
array=None,
how=MFArrayType.internal,
factor=None,
block=None,
Expand Down Expand Up @@ -328,53 +328,42 @@ def write(self, f):
pass

@classmethod
def load(cls, f, cwd, shape, layered=False):
"""
Parameters
----------
f
def load(cls, f, cwd, shape, header=True, **kwargs):
layered = kwargs.pop("layered", False)

if header:
tokens = multi_line_strip(f).split()
name = tokens[0]
kwargs.pop("name", None)
if len(tokens) > 1 and tokens[1] == "layered":
layered = True
else:
name = kwargs.pop("name", None)

Returns
-------
MFArray
"""
if layered:
nlay = shape[0]
lay_shape = shape[1:]
lshp = shape[1:]
objs = []
for _ in range(nlay):
mfa = cls._load(f, cwd, lay_shape)
mfa = cls._load(f, cwd, lshp, name)
objs.append(mfa)

mfa = MFArray(
np.array(objs, dtype=object),
return MFArray(
shape,
array=np.array(objs, dtype=object),
how=None,
factor=None,
name=name,
layered=True,
)

else:
mfa = cls._load(f, cwd, shape, layered=layered)

return mfa
kwargs.pop("layered", None)
return cls._load(
f, cwd, shape, layered=layered, name=name, **kwargs
)

@classmethod
def _load(cls, f, cwd, shape, layered=False):
"""
Parameters
----------
f
cwd
shape
layered
Returns
-------
"""
def _load(cls, f, cwd, shape, layered=False, **kwargs):
control_line = multi_line_strip(f).split()

if CommonNames.iprn.lower() in control_line:
Expand All @@ -392,7 +381,7 @@ def _load(cls, f, cwd, shape, layered=False):
array = float(control_line[clpos])
clpos += 1

elif how == how.external:
elif how == MFArrayType.external:
ext_path = Path(control_line[clpos])
fpath = cwd / ext_path
with open(fpath) as foo:
Expand All @@ -406,8 +395,7 @@ def _load(cls, f, cwd, shape, layered=False):
if len(control_line) > 2:
factor = float(control_line[clpos + 1])

mfa = cls(array, shape, how, factor=factor)
return mfa
return cls(shape, array=array, how=how, factor=factor, **kwargs)

@staticmethod
def read_array(f):
Expand All @@ -433,6 +421,7 @@ def read_array(f):
CommonNames.internal in line
or CommonNames.external in line
or CommonNames.constant in line
or CommonNames.end in line.upper()
):
f.seek(pos, 0)
break
Expand Down
28 changes: 24 additions & 4 deletions flopy4/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import asdict
from typing import Any

from flopy4.array import MFArray
from flopy4.parameter import MFParameter, MFParameters
from flopy4.utils import strip

Expand Down Expand Up @@ -30,6 +31,21 @@ class MFBlockMappingMeta(MFBlockMeta, ABCMeta):


class MFBlock(MFParameters, metaclass=MFBlockMappingMeta):
"""
MF6 input block. Maps parameter names to parameters.
Notes
-----
This class is dynamically subclassed by `MFPackage`
to match each block within a package parameter set.
Supports dictionary and attribute access. The class
attributes specify the block's parameters. Instance
attributes contain both the specification and value.
The block's name and index are discovered upon load.
"""

def __init__(self, name=None, index=None, params=None):
self.name = name
self.index = index
Expand All @@ -49,7 +65,7 @@ def params(self) -> MFParameters:
return self.data

@classmethod
def load(cls, f):
def load(cls, f, **kwargs):
name = None
index = None
found = False
Expand All @@ -72,9 +88,13 @@ def load(cls, f):
param = members.get(key)
if param is not None:
f.seek(pos)
params[key] = type(param).load(
f, **asdict(param.with_name(key).with_block(name))
)
spec = asdict(param.with_name(key).with_block(name))
kwargs = {**kwargs, **spec}
if type(param) is MFArray:
# TODO: inject from model somehow?
# and remove special handling here
kwargs["cwd"] = ""
params[key] = type(param).load(f, **kwargs)

return cls(name, index, params)

Expand Down
1 change: 1 addition & 0 deletions flopy4/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ class CommonNames:
vertex = "vertex"
unstructured = "unstructured"
empty = ""
end = "END"
5 changes: 4 additions & 1 deletion flopy4/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ class MFPackageMappingMeta(MFPackageMeta, ABCMeta):

class MFPackage(UserDict, metaclass=MFPackageMappingMeta):
"""
MF6 package base class.
MF6 model or simulation component package.
TODO: reimplement with `ChainMap`?
"""
Expand Down Expand Up @@ -93,6 +94,8 @@ def load(cls, f):
line = f.readline()
if line == "":
break
if line == "\n":
continue
line = strip(line).lower()
words = line.split()
key = words[0]
Expand Down
13 changes: 11 additions & 2 deletions flopy4/parameter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from abc import abstractmethod
from ast import literal_eval
from collections import UserDict
from dataclasses import dataclass, fields
from enum import Enum
from typing import Any, Optional
from typing import Any, Optional, Tuple


class MFReader(Enum):
Expand Down Expand Up @@ -40,6 +41,7 @@ class MFParamSpec:
repeating: bool = False
tagged: bool = True
reader: MFReader = MFReader.urword
shape: Optional[Tuple[int]] = None
default_value: Optional[Any] = None

@classmethod
Expand Down Expand Up @@ -74,6 +76,8 @@ def load(cls, f) -> "MFParamSpec":
spec[key] = val == "true"
elif key == "reader":
spec[key] = MFReader.from_str(val)
elif key == "shape":
spec[key] = literal_eval(val)
else:
spec[key] = val
return cls(**spec)
Expand Down Expand Up @@ -134,6 +138,7 @@ def __init__(
repeating=False,
tagged=False,
reader=MFReader.urword,
shape=None,
default_value=None,
):
super().__init__(
Expand All @@ -150,6 +155,7 @@ def __init__(
repeating=repeating,
tagged=tagged,
reader=reader,
shape=shape,
default_value=default_value,
)

Expand All @@ -161,7 +167,10 @@ def value(self) -> Optional[Any]:


class MFParameters(UserDict):
"""Mapping of parameter names to parameters."""
"""
Mapping of parameter names to parameters.
Supports dictionary and attribute access.
"""

def __init__(self, params=None):
super().__init__(params)
Expand Down
12 changes: 12 additions & 0 deletions flopy4/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
repeating=False,
tagged=False,
reader=MFReader.urword,
shape=None,
default_value=None,
):
self._value = value
Expand All @@ -43,6 +44,7 @@ def __init__(
repeating,
tagged,
reader,
shape,
default_value,
)

Expand All @@ -68,6 +70,7 @@ def __init__(
repeating=False,
tagged=False,
reader=MFReader.urword,
shape=None,
default_value=None,
):
super().__init__(
Expand All @@ -85,6 +88,7 @@ def __init__(
repeating,
tagged,
reader,
shape,
default_value,
)

Expand Down Expand Up @@ -122,6 +126,7 @@ def __init__(
repeating=False,
tagged=False,
reader=MFReader.urword,
shape=None,
default_value=None,
):
super().__init__(
Expand All @@ -139,6 +144,7 @@ def __init__(
repeating,
tagged,
reader,
shape,
default_value,
)

Expand Down Expand Up @@ -174,6 +180,7 @@ def __init__(
repeating=False,
tagged=False,
reader=MFReader.urword,
shape=None,
default_value=None,
):
super().__init__(
Expand All @@ -191,6 +198,7 @@ def __init__(
repeating,
tagged,
reader,
shape,
default_value,
)

Expand Down Expand Up @@ -226,6 +234,7 @@ def __init__(
repeating=False,
tagged=False,
reader=MFReader.urword,
shape=None,
default_value=None,
):
super().__init__(
Expand All @@ -243,6 +252,7 @@ def __init__(
repeating,
tagged,
reader,
shape,
default_value,
)

Expand Down Expand Up @@ -290,6 +300,7 @@ def __init__(
repeating=False,
tagged=False,
reader=MFReader.urword,
shape=None,
default_value=None,
):
self.inout = inout
Expand All @@ -308,6 +319,7 @@ def __init__(
repeating,
tagged,
reader,
shape,
default_value,
)

Expand Down
Loading

0 comments on commit e500ac4

Please sign in to comment.