Skip to content

Commit

Permalink
Refactoring state and checkpointing (#34)
Browse files Browse the repository at this point in the history
* Add a Checkpoint class 
* Add documentation to functions controlling checkpointing
* Add SolverState object to wrap FEniCS solver's state update
* Use is_timestep_complete API function
* Use module logger
* Update tests
  • Loading branch information
BenjaminRodenberg authored Nov 27, 2019
1 parent 1e675e0 commit 550ca38
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 41 deletions.
31 changes: 31 additions & 0 deletions fenicsadapter/checkpointing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from .solverstate import SolverState

class Checkpoint:

def __init__(self):
"""
A checkpoint for the solver state
"""
self._state = None

def get_state(self):
return self._state

def write(self, new_state):
"""
write checkpoint from solver state.
:param u: function value
:param t: time
:param n: timestep
"""
if self.is_empty():
self._state = SolverState(None, None, None)

self._state.copy(new_state)

def is_empty(self):
"""
Returns whether checkpoint is empty. An empty checkpoint has no state saved.
:return:
"""
return not self._state
89 changes: 58 additions & 31 deletions fenicsadapter/fenicsadapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,20 @@
:raise ImportError: if PRECICE_ROOT is not defined
"""
import dolfin
from dolfin import Point, UserExpression, SubDomain, Function, Measure, Expression, dot, PointSource, FacetNormal
from dolfin import UserExpression, SubDomain, Function
from scipy.interpolate import Rbf
from scipy.interpolate import interp1d
import numpy as np
from .config import Config
from .checkpointing import Checkpoint
from .solverstate import SolverState
from enum import Enum
import logging

logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)


try:
import precice_future as precice
except ImportError:
Expand Down Expand Up @@ -182,7 +188,7 @@ def create_interpolant(self):
else:
raise Exception("Problem dimension and data dimension not matching.")
elif self._dimension == 3:
logging.warning("RBF Interpolation for 3D Simulations has not been properly tested!")
logger.warning("RBF Interpolation for 3D Simulations has not been properly tested!")
if self.is_scalar_valued():
interpolant.append(Rbf(self._coords_x, self._coords_y, self._coords_z, self._vals.flatten()))
elif self.is_vector_valued():
Expand Down Expand Up @@ -298,9 +304,7 @@ def __init__(self, adapter_config_filename='precice-adapter-config.json',
self._my_expression = interpolation_strategy

# checkpointing
self._u_cp = None # checkpoint for temperature inside domain
self._t_cp = None # time of the checkpoint
self._n_cp = None # timestep of the checkpoint
self._checkpoint = Checkpoint()

# function space
self._function_space = None
Expand Down Expand Up @@ -584,12 +588,40 @@ def _get_forces_as_point_sources(self):

return x_forces.values(), y_forces.values() # don't return dictionary, but list of PointSources

def _restore_solver_state_from_checkpoint(self, state):
"""Resets the solver's state to the checkpoint's state.
:param state: current state of the FEniCS solver
"""
logger.debug("Restore solver state")
state.update(self._checkpoint.get_state())
self._interface.fulfilled_action(precice.action_read_iteration_checkpoint())

def _advance_solver_state(self, state, u_np1, dt):
"""Advances the solver's state by one timestep.
:param state: old state
:param u_np1: new value
:param dt: timestep size
:return:
"""
logger.debug("Advance solver state")
logger.debug("old state: t={time}".format(time=state.t))
state.update(SolverState(u_np1, state.t + dt, state.n + 1))
logger.debug("new state: t={time}".format(time=state.t))

def _save_solver_state_to_checkpoint(self, state):
"""Writes given solver state to checkpoint.
:param state: state being saved as checkpoint
"""
logger.debug("Save solver state")
self._checkpoint.write(state)
self._interface.fulfilled_action(precice.action_write_iteration_checkpoint())

def advance(self, write_function, u_np1, u_n, t, dt, n):
"""Calls preCICE advance function using precice and manages checkpointing.
The solution u_n is updated by this function via call-by-reference. The corresponding values for t and n are returned.
This means:
* either, the checkpoint self._u_cp is assigned to u_n to repeat the iteration,
* either, the old value of the checkpoint is assigned to u_n to repeat the iteration,
* or u_n+1 is assigned to u_n and the checkpoint is updated correspondingly.
:param write_function: a FEniCS function being sent to the other participant as boundary condition at the coupling interface
Expand All @@ -601,6 +633,8 @@ def advance(self, write_function, u_np1, u_n, t, dt, n):
:return: return starting time t and timestep n for next FEniCS solver iteration. u_n is updated by advance correspondingly.
"""

state = SolverState(u_n, t, n)

# sample write data at interface
x_vert, y_vert = self._extract_coupling_boundary_coordinates()
self._write_data = self._convert_fenics_to_precice(write_function)
Expand All @@ -618,29 +652,24 @@ def advance(self, write_function, u_np1, u_n, t, dt, n):
else:
self._coupling_bc_expression.update_boundary_data(self._read_data, x_vert, y_vert)

precice_step_complete = False
solver_state_has_been_restored = False

# checkpointing
if self._interface.is_action_required(precice.action_read_iteration_checkpoint()):
# continue FEniCS computation from checkpoint
u_n.assign(self._u_cp) # set u_n to value of checkpoint
t = self._t_cp
n = self._n_cp
self._interface.fulfilled_action(precice.action_read_iteration_checkpoint())
assert (not self._interface.is_timestep_complete()) # avoids invalid control flow
self._restore_solver_state_from_checkpoint(state)
solver_state_has_been_restored = True
else:
u_n.assign(u_np1)
t = new_t = t + dt # TODO: the variables new_t, new_n could be saved, by just using t and n below, however I think it improved readability.
n = new_n = n + 1
self._advance_solver_state(state, u_np1, dt)

if self._interface.is_action_required(precice.action_write_iteration_checkpoint()):
# continue FEniCS computation with u_np1
# update checkpoint
self._u_cp.assign(u_np1)
self._t_cp = new_t
self._n_cp = new_n
self._interface.fulfilled_action(precice.action_write_iteration_checkpoint())
precice_step_complete = True
assert (not solver_state_has_been_restored) # avoids invalid control flow
assert (self._interface.is_timestep_complete()) # avoids invalid control flow
self._save_solver_state_to_checkpoint(state)

precice_step_complete = self._interface.is_timestep_complete()

_, t, n = state.get_state()
# TODO: this if-else statement smells.
if self._has_force_boundary:
return t, n, precice_step_complete, max_dt, x_forces, y_forces
Expand Down Expand Up @@ -673,10 +702,10 @@ def initialize(self, coupling_subdomain, mesh, read_field, write_field,
self._fenics_dimensions = dimension

if self._fenics_dimensions != self._dimensions:
logging.warning("fenics_dimension = {} and precice_dimension = {} do not match!".format(self._fenics_dimensions,
logger.warning("fenics_dimension = {} and precice_dimension = {} do not match!".format(self._fenics_dimensions,
self._dimensions))
if self._can_apply_2d_3d_coupling():
logging.warning("2D-3D coupling will be applied. Z coordinates of all nodes will be set to zero.")
logger.warning("2D-3D coupling will be applied. Z coordinates of all nodes will be set to zero.")
else:
raise Exception("fenics_dimension = {}, precice_dimension = {}. "
"No proper treatment for dimensional mismatch is implemented. Aborting!".format(
Expand All @@ -689,22 +718,20 @@ def initialize(self, coupling_subdomain, mesh, read_field, write_field,
self._set_read_field(read_field)
self._set_write_field(write_field)
self._precice_tau = self._interface.initialize()

if self._interface.is_action_required(precice.action_write_initial_data()):
self._write_block_data()
self._interface.fulfilled_action(precice.action_write_initial_data())

self._interface.initialize_data()

if self._interface.is_read_data_available():
self._read_block_data()

if self._interface.is_action_required(precice.action_write_iteration_checkpoint()):
self._u_cp = u_n.copy(deepcopy=True)
self._t_cp = t
self._n_cp = n
self._interface.fulfilled_action(precice.action_write_iteration_checkpoint())

initial_state = SolverState(u_n, t, n)
self._save_solver_state_to_checkpoint(initial_state)

return self._precice_tau

def is_coupling_ongoing(self):
Expand Down
42 changes: 42 additions & 0 deletions fenicsadapter/solverstate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
class SolverState:
def __init__(self, u, t, n):
"""
Solver state consists of a value u, associated time t and the timestep n
:param u: value
:param t: time
:param n: timestep
"""
self.u = u
self.t = t
self.n = n

def get_state(self):
"""
returns the state variables value u, associated time t and timestep n
:return:
"""
return self.u, self.t, self.n

def update(self, other_state):
"""
updates the state using FEniCS assing function. self.u is updated.
This may also have an effect outside of this object! Compare to SolverState.copy(other_state).
:param other_state:
"""
self.u.assign(other_state.u)
self.t = other_state.t
self.n = other_state.n

def copy(self, other_state):
"""
copies a state using FEniCS copy function. self.u is overwritten.
This does not have an effect outside of this object! Compare to SolverState.update(other_state).
:param other_state:
"""
self.u = other_state.u.copy()
self.t = other_state.t
self.n = other_state.n

def print_state(self):
u, t, n = self.get_state()
return "u={u}, t={t}, n={n}".format(u=u, t=t, n=n)
3 changes: 3 additions & 0 deletions tests/MockedPrecice.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,6 @@ def configure(self, foo):

def get_dimensions(self):
raise Exception("not implemented")

def is_timestep_complete(self):
raise Exception("not implemented")
27 changes: 17 additions & 10 deletions tests/test_fenicsadapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ def assign(self, new_value):
"""
self.value = new_value.value

def copy(self):
returned_array = MockedArray()
returned_array.value = self.value
return returned_array

def value_rank(self):
return 0

Expand Down Expand Up @@ -53,7 +58,7 @@ def setUp(self):
warnings.simplefilter('ignore', category=ImportWarning)

def mock_the_adapter(self, precice):
from fenicsadapter.fenicsadapter import FunctionType
from fenicsadapter.fenicsadapter import FunctionType, SolverState
"""
We partially mock the fenicsadapter, since proper configuration and initialization of the adapter is not
necessary to test checkpointing.
Expand All @@ -65,9 +70,8 @@ def mock_the_adapter(self, precice):
precice._coupling_bc_expression = MagicMock()
precice._coupling_bc_expression.update_boundary_data = MagicMock()
# initialize checkpointing manually
precice._t_cp = self.t_cp_mocked
precice._u_cp = self.u_cp_mocked
precice._n_cp = self.n_cp_mocked
mocked_state = SolverState(self.u_cp_mocked, self.t_cp_mocked, self.n_cp_mocked)
precice._checkpoint.write(mocked_state)
precice._write_function_type = FunctionType.SCALAR
precice._read_function_type = FunctionType.SCALAR

Expand All @@ -92,6 +96,7 @@ def is_action_required_behavior(py_action):
Interface.get_data_id = MagicMock()
Interface.write_block_scalar_data = MagicMock()
Interface.read_block_scalar_data = MagicMock()
Interface.is_timestep_complete = MagicMock(return_value=True)
Interface.advance = MagicMock(return_value=self.dt)
Interface.fulfilled_action = MagicMock()

Expand All @@ -109,8 +114,8 @@ def is_action_required_behavior(py_action):
# we expect that self.u_n_mocked.value has been updated to self.u_np1_mocked.value
self.assertEqual(self.u_n_mocked.value, self.u_np1_mocked.value)

# we expect that precice._u_cp.value has been updated to value_u_np1
self.assertEqual(precice._u_cp.value, value_u_np1)
# we expect that the value of the checkpoint has been updated to value_u_np1
self.assertEqual(precice._checkpoint.get_state().u.value, value_u_np1)

def test_advance_rollback(self):
"""
Expand All @@ -133,6 +138,7 @@ def is_action_required_behavior(py_action):
Interface.get_data_id = MagicMock()
Interface.write_block_scalar_data = MagicMock()
Interface.read_block_scalar_data = MagicMock()
Interface.is_timestep_complete = MagicMock(return_value=False)
Interface.advance = MagicMock(return_value=self.dt)
Interface.fulfilled_action = MagicMock()

Expand All @@ -148,8 +154,8 @@ def is_action_required_behavior(py_action):
# we expect that self.u_n_mocked.value has been rolled back to self.u_cp_mocked.value
self.assertEqual(self.u_n_mocked.value, self.u_cp_mocked.value)

# we expect that precice._u_cp.value has not been updated
self.assertEqual(precice._u_cp.value, self.u_cp_mocked.value)
# we expect that precice._checkpoint.get_state().u has not been updated
self.assertEqual(precice._checkpoint.get_state().u.value, self.u_cp_mocked.value)

def test_advance_continue(self):
"""
Expand All @@ -171,6 +177,7 @@ def is_action_required_behavior(py_action):
Interface.get_dimensions = MagicMock()
Interface.get_mesh_id = MagicMock()
Interface.get_data_id = MagicMock()
Interface.is_timestep_complete = MagicMock(return_value=False)
Interface.write_block_scalar_data = MagicMock()
Interface.read_block_scalar_data = MagicMock()
Interface.advance = MagicMock(return_value=self.dt)
Expand All @@ -188,8 +195,8 @@ def is_action_required_behavior(py_action):
# we expect that self.u_n_mocked.value has been updated to self.u_np1_mocked.value
self.assertEqual(self.u_n_mocked.value, self.u_np1_mocked.value)

# we expect that precice._u_cp.value has not been updated
self.assertEqual(precice._u_cp.value, self.u_cp_mocked.value)
# we expect that the value of the checkpoint has not been updated
self.assertEqual(precice._checkpoint.get_state().u.value, self.u_cp_mocked.value)


@patch.dict('sys.modules', **{'dolfin': fake_dolfin, 'precice_future': tests.MockedPrecice})
Expand Down
4 changes: 4 additions & 0 deletions tests/test_write_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def dummy_set_mesh_vertices(mesh_id, positions):
Interface.initialize_data = MagicMock()
Interface.is_action_required = MagicMock(return_value=False)
Interface.fulfilled_action = MagicMock()
Interface.is_timestep_complete = MagicMock()
Interface.advance = MagicMock()
Interface.get_mesh_id = MagicMock()
Interface.get_data_id = MagicMock(return_value=15)
Expand Down Expand Up @@ -97,6 +98,7 @@ def dummy_set_mesh_vertices(mesh_id, positions):
Interface.initialize_data = MagicMock()
Interface.is_action_required = MagicMock(return_value=False)
Interface.fulfilled_action = MagicMock()
Interface.is_timestep_complete = MagicMock()
Interface.advance = MagicMock()
Interface.get_mesh_id = MagicMock()
Interface.get_data_id = MagicMock(return_value=15)
Expand Down Expand Up @@ -147,6 +149,7 @@ def dummy_set_mesh_vertices(mesh_id, positions):
Interface.initialize_data = MagicMock()
Interface.is_action_required = MagicMock(return_value=False)
Interface.fulfilled_action = MagicMock()
Interface.is_timestep_complete = MagicMock()
Interface.advance = MagicMock()
Interface.get_mesh_id = MagicMock()
Interface.get_data_id = MagicMock(return_value=15)
Expand Down Expand Up @@ -194,6 +197,7 @@ def dummy_set_mesh_vertices(mesh_id, positions):
Interface.initialize_data = MagicMock()
Interface.is_action_required = MagicMock(return_value=False)
Interface.fulfilled_action = MagicMock()
Interface.is_timestep_complete = MagicMock()
Interface.advance = MagicMock()
Interface.get_mesh_id = MagicMock()
Interface.get_data_id = MagicMock(return_value=15)
Expand Down

0 comments on commit 550ca38

Please sign in to comment.