Skip to content

Commit

Permalink
Making blocks more generic
Browse files Browse the repository at this point in the history
  • Loading branch information
jezsadler committed Sep 28, 2024
1 parent 164b179 commit 6dfd2c0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 15 deletions.
26 changes: 14 additions & 12 deletions src/omlt/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,7 @@ class is used in combination with a formulation object to construct the
from omlt.base import DEFAULT_MODELING_LANGUAGE, OmltVarFactory


@declare_custom_block(name="OmltBlock")
class OmltBlockData(BlockData):
def __init__(self, component):
super().__init__(component)
self.__formulation = None
self.__input_indexes = None
self.__output_indexes = None
self._format = DEFAULT_MODELING_LANGUAGE

def set_format(self, lang):
self._format = lang

class OmltBlockCore:
def _setup_inputs_outputs(self, *, input_indexes, output_indexes):
"""Setup inputs and outputs.
Expand Down Expand Up @@ -117,3 +106,16 @@ def build_formulation(self, formulation, lang=None):

# tell the formulation object to construct the necessary models
self.__formulation._build_formulation()


@declare_custom_block(name="OmltBlock")
class OmltBlockData(_BlockData, OmltBlockCore):
def __init__(self, component):
super().__init__(component)
self.__formulation = None
self.__input_indexes = None
self.__output_indexes = None
self._format = DEFAULT_MODELING_LANGUAGE

def set_format(self, lang):
self._format = lang
6 changes: 3 additions & 3 deletions tests/base/test_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def test_block():
formulation = DummyFormulation()
m.b.build_formulation(formulation, lang="pyomo")

assert m.b._OmltBlockData__formulation is formulation
assert list(m.b.inputs) == ["A", "C", "D"]
assert list(m.b.outputs) == [(0, 0), (0, 1), (1, 0), (1, 1)]
assert m.b._OmltBlockCore__formulation is formulation
assert [k for k in m.b.inputs] == ["A", "C", "D"]
assert [k for k in m.b.outputs] == [(0, 0), (0, 1), (1, 0), (1, 1)]


def test_input_output_auto_creation():
Expand Down

0 comments on commit 6dfd2c0

Please sign in to comment.