Skip to content

Commit

Permalink
minimal block impl
Browse files Browse the repository at this point in the history
  • Loading branch information
wpbonelli committed Jul 18, 2024
1 parent 4b1f425 commit f1bf5c2
Show file tree
Hide file tree
Showing 8 changed files with 522 additions and 119 deletions.
66 changes: 45 additions & 21 deletions flopy4/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from flopy.utils.flopy_io import line_strip, multi_line_strip

from flopy4.constants import CommonNames
from flopy4.parameter import MFParameter
from flopy4.parameter import MFParameter, MFReader


class NumPyArrayMixin:
Expand All @@ -23,7 +23,7 @@ class NumPyArrayMixin:
"""

def __iadd__(self, other):
if self._layered:
if self.layered:
for mfa in self._value:
mfa += other
return self
Expand All @@ -32,7 +32,7 @@ def __iadd__(self, other):
return self

def __imul__(self, other):
if self._layered:
if self.layered:
for mfa in self._value:
mfa *= other
return self
Expand All @@ -41,7 +41,7 @@ def __imul__(self, other):
return self

def __isub__(self, other):
if self._layered:
if self.layered:
for mfa in self._value:
mfa -= other
return self
Expand All @@ -50,7 +50,7 @@ def __isub__(self, other):
return self

def __itruediv__(self, other):
if self._layered:
if self.layered:
for mfa in self._value:
mfa /= other
return self
Expand All @@ -59,7 +59,7 @@ def __itruediv__(self, other):
return self

def __ifloordiv__(self, other):
if self._layered:
if self.layered:
for mfa in self._value:
mfa /= other
return self
Expand All @@ -68,7 +68,7 @@ def __ifloordiv__(self, other):
return self

def __ipow__(self, other):
if self._layered:
if self.layered:
for mfa in self._value:
mfa /= other
return self
Expand All @@ -77,7 +77,7 @@ def __ipow__(self, other):
return self

def __add__(self, other):
if self._layered:
if self.layered:
for mfa in self._value:
mfa += other
return self
Expand All @@ -86,7 +86,7 @@ def __add__(self, other):
return self

def __mul__(self, other):
if self._layered:
if self.layered:
for mfa in self._value:
mfa *= other
return self
Expand All @@ -95,7 +95,7 @@ def __mul__(self, other):
return self

def __sub__(self, other):
if self._layered:
if self.layered:
for mfa in self._value:
mfa -= other
return self
Expand All @@ -104,7 +104,7 @@ def __sub__(self, other):
return self

def __truediv__(self, other):
if self._layered:
if self.layered:
for mfa in self._value:
mfa /= other
return self
Expand All @@ -113,7 +113,7 @@ def __truediv__(self, other):
return self

def __floordiv__(self, other):
if self._layered:
if self.layered:
for mfa in self._value:
mfa /= other
return self
Expand All @@ -122,7 +122,7 @@ def __floordiv__(self, other):
return self

def __pow__(self, other):
if self._layered:
if self.layered:
for mfa in self._value:
mfa /= other
return self
Expand Down Expand Up @@ -187,26 +187,50 @@ def __init__(
shape,
how=MFArrayType.internal,
factor=None,
layered=False,
block=None,
name=None,
longname=None,
description=None,
deprecated=False,
in_record=False,
layered=False,
optional=True,
numeric_index=False,
preserve_case=False,
repeating=False,
tagged=False,
reader=MFReader.urword,
default_value=None,
):
MFParameter.__init__(self, name, longname, description, optional)
MFParameter.__init__(
self,
block=block,
name=name,
longname=longname,
description=description,
deprecated=deprecated,
in_record=in_record,
layered=layered,
optional=optional,
numeric_index=numeric_index,
preserve_case=preserve_case,
repeating=repeating,
tagged=tagged,
reader=reader,
default_value=default_value,
)
self._value = array
self._shape = shape
self._how = how
self._factor = factor
self._layered = layered

def __getitem__(self, item):
return self.raw[item]

def __setitem__(self, key, value):
values = self.raw
values[key] = value
if self._layered:
if self.layered:
for ix, mfa in enumerate(self._value):
mfa[:] = values[ix]
return
Expand Down Expand Up @@ -247,7 +271,7 @@ def value(self) -> np.ndarray:
"""
Return the array.
"""
if self._layered:
if self.layered:
arr = []
for mfa in self._value:
arr.append(mfa.value)
Expand All @@ -263,7 +287,7 @@ def raw(self):
"""
Return the array without multiplying by `self.factor`.
"""
if self._layered:
if self.layered:
arr = []
for mfa in self._value:
arr.append(mfa.raw)
Expand All @@ -279,7 +303,7 @@ def factor(self) -> Optional[float]:
"""
Optional factor by which to multiply array elements.
"""
if self._layered:
if self.layered:
factor = [mfa.factor for mfa in self._value]
return factor

Expand All @@ -293,7 +317,7 @@ def how(self):
"""
How the array is to be written to the input file.
"""
if self._layered:
if self.layered:
how = [mfa.how for mfa in self._value]
return how

Expand Down
88 changes: 88 additions & 0 deletions flopy4/block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from collections.abc import MutableMapping
from typing import Any, Dict

from flopy4.parameter import MFParameter
from flopy4.utils import strip


def get_member_params(cls) -> Dict[str, MFParameter]:
if not issubclass(cls, MFBlock):
raise ValueError(f"Expected MFBlock, got {cls}")

return {
k: v
for k, v in cls.__dict__.items()
if issubclass(type(v), MFParameter)
}


class MFBlock(MutableMapping):
def __init__(self, name=None, index=None, *args, **kwargs):
self.name = name
self.index = index
self.params = dict()
self.update(dict(*args, **kwargs))
for key, param in self.items():
setattr(self, key, param)

def __getattribute__(self, name: str) -> Any:
attr = super().__getattribute__(name)
if isinstance(attr, MFParameter):
# shortcut to parameter value for instance attributes.
# the class attribute is the full param specification.
return attr.value
else:
return attr

def __getitem__(self, key):
return self.params[key]

def __setitem__(self, key, value):
self.params[key] = value

def __delitem__(self, key):
del self.params[key]

def __iter__(self):
return iter(self.params)

def __len__(self):
return len(self.params)

@classmethod
def load(cls, f):
name = None
index = None
found = False
params = dict()
members = get_member_params(cls)
while True:
pos = f.tell()
line = strip(f.readline()).lower()
words = line.split()
key = words[0]
if key == "begin":
found = True
name = words[1]
if len(words) > 2 and str.isdigit(words[2]):
index = words[2]
elif key == "end":
break
elif found:
if key in members:
f.seek(pos)
param = members[key]
param.block = name
params[key] = type(param).load(f, spec=param)

return cls(name, index, **params)

def write(self, f):
index = self.index if self.index is not None else ""
begin = f"BEGIN {self.name.upper()} {index}\n"
end = f"END {self.name.upper()}\n"

f.write(begin)
for param in self.values():
param.write(f)
f.write(end)
Loading

0 comments on commit f1bf5c2

Please sign in to comment.