Skip to content

Commit

Permalink
refactor: use StateFacade
Browse files Browse the repository at this point in the history
  • Loading branch information
aaraney committed Feb 7, 2025
1 parent 43c3774 commit fd64c7d
Showing 1 changed file with 13 additions and 18 deletions.
31 changes: 13 additions & 18 deletions lstm/bmi_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from . import nextgen_cuda_lstm
from .base import BmiBase
from .logger import configure_logging, logger
from .model_state import State, StateValues, Var
from .model_state import State, StateFacade, Var

# -------------- Dynamic Attributes -----------------------------
_dynamic_input_vars = [
Expand Down Expand Up @@ -387,7 +387,7 @@ def update(self) -> None:
#
# each ensemble member will query the `state` object for its required inputs.
# this could ensemble members with a different number of required features in the future.
state = StateValues(self._dynamic_inputs, self._static_inputs)
state = StateFacade(self._dynamic_inputs, self._static_inputs)

outputs: dict[str, list[float]] = collections.defaultdict(list)
for member in self.ensemble_members:
Expand Down Expand Up @@ -441,19 +441,15 @@ def get_output_var_names(self) -> tuple[str, ...]: # type: ignore

def get_var_grid(self, name: str) -> int:
# Note: all vars have grid 0 but check if its in names list first
if name in self._dynamic_inputs or name in self._outputs:
if name in StateFacade(self._outputs, self._dynamic_inputs):
return 0
raise RuntimeError(f"unknown name: {name!s}")

def get_var_type(self, name: str) -> str:
return self.get_value_ptr(name).dtype.name

def get_var_units(self, name: str) -> str:
if name in self._dynamic_inputs:
return self._dynamic_inputs.unit(name)
elif name in self._outputs:
return self._outputs.unit(name)
raise RuntimeError(f"unknown name: {name!s}")
return StateFacade(self._outputs, self._dynamic_inputs).unit(name)

def get_var_itemsize(self, name: str) -> int:
return self.get_value_ptr(name).itemsize
Expand All @@ -462,8 +458,7 @@ def get_var_nbytes(self, name: str) -> int:
return self.get_var_itemsize(name) * len(self.get_value_ptr(name))

def get_var_location(self, name: str) -> str:
# Note: all vars have location node but check if its in names list first
if name in self._dynamic_inputs or name in self._outputs:
if name in StateFacade(self._outputs, self._dynamic_inputs):
return "node"
raise RuntimeError(f"unknown name: {name!s}")

Expand All @@ -487,24 +482,24 @@ def get_value(self, name: str, dest: np.ndarray) -> np.ndarray:
return dest

def get_value_ptr(self, name: str) -> np.ndarray:
# NOTE: aaraney: I think we just want to allow getting outputs?
return self._outputs.value(name)
return StateFacade(self._outputs, self._dynamic_inputs).value(name)

def get_value_at_indices(
self, name: str, dest: np.ndarray, inds: np.ndarray
) -> np.ndarray:
# NOTE: aaraney: I think we just want to allow getting outputs?
return self._outputs.value_at_indices(name, dest, inds)
return StateFacade(self._outputs, self._dynamic_inputs).value_at_indices(
name, dest, inds
)

def set_value(self, name: str, src: np.ndarray) -> None:
# NOTE: aaraney: I think we just want to allow setting inputs?
self._dynamic_inputs.set_value(name, src)
return StateFacade(self._outputs, self._dynamic_inputs).set_value(name, src)

def set_value_at_indices(
self, name: str, inds: np.ndarray, src: np.ndarray
) -> None:
# NOTE: aaraney: I think we just want to allow setting inputs?
self._dynamic_inputs.set_value_at_indices(name, inds, src)
return StateFacade(self._outputs, self._dynamic_inputs).set_value_at_indices(
name, inds, src
)

# Grid information
def get_grid_rank(self, grid: int) -> int:
Expand Down

0 comments on commit fd64c7d

Please sign in to comment.