Skip to content

Commit

Permalink
Avoid configuring the model twice
Browse files Browse the repository at this point in the history
  • Loading branch information
pietrolesci committed Feb 28, 2024
1 parent f434049 commit 2f3b3d3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
5 changes: 3 additions & 2 deletions energizer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,8 +499,9 @@ def test_epoch_end(self, output: list[BATCH_OUTPUT], metrics: METRIC | None) ->
"""

def configure_model(self) -> None:
    """Configure the underlying model exactly once.

    Runs the model's own ``configure_model`` inside the fabric's
    ``init_module`` context so that parameters are materialized with the
    correct device/dtype/precision settings. The ``is_configured`` guard
    ensures the model is never configured twice (e.g. when several
    entry points on the same estimator each call this method).
    """
    # Diff residue removed: the old unconditional configure call preceded
    # the guarded one here, which would have configured the model twice.
    if not self.model.is_configured:
        with self.fabric.init_module():
            self.model.configure_model()

def configure_optimization_args(
self,
Expand Down
9 changes: 8 additions & 1 deletion energizer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

class Model(ABC):
_model_instance: torch.nn.Module | None = None
_is_configured: bool = False

@property
def model_instance(self) -> torch.nn.Module:
Expand All @@ -32,14 +33,18 @@ def summary(self) -> str:
def configure_model(self, *args, **kwargs) -> None:
...

@property
def is_configured(self) -> bool:
return self._is_configured


class TorchModel(Model):
    """``Model`` wrapper around an already-instantiated ``torch.nn.Module``.

    The wrapped module is passed in fully built, so ``configure_model``
    has nothing to instantiate — it only records that configuration ran.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        # Eagerly provided module: available immediately, no lazy build step.
        self._model_instance = model

    def configure_model(self, *args, **kwargs) -> None:
        # Diff residue removed: a stray dead `pass` statement preceded this
        # line. Nothing to build here; just mark the model as configured so
        # callers (e.g. the estimator's configure_model) skip reconfiguring.
        self._is_configured = True


class HFModel(Model):
Expand Down Expand Up @@ -147,6 +152,8 @@ def configure_model(self) -> None:
if self._convert_to_bettertransformer:
self.convert_to_bettertransformer()

self._is_configured = True

def convert_to_bettertransformer(self) -> None:
    """Replace the wrapped module with its BetterTransformer-converted form."""
    instance = self._model_instance
    # Conversion requires an instantiated module — configure the model first.
    assert instance is not None
    self._model_instance = instance.to_bettertransformer()
Expand Down

0 comments on commit 2f3b3d3

Please sign in to comment.