From 2f3b3d3af03ed8e9be487a698125fa1be86b23e2 Mon Sep 17 00:00:00 2001
From: pietrolesci
Date: Wed, 28 Feb 2024 16:20:55 +0000
Subject: [PATCH] avoid configuring model twice

---
 energizer/estimator.py | 5 +++--
 energizer/models.py    | 9 ++++++++-
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/energizer/estimator.py b/energizer/estimator.py
index 398d90e..4ffc751 100644
--- a/energizer/estimator.py
+++ b/energizer/estimator.py
@@ -499,8 +499,9 @@ def test_epoch_end(self, output: list[BATCH_OUTPUT], metrics: METRIC | None) ->
         """
 
     def configure_model(self) -> None:
-        with self.fabric.init_module():
-            self.model.configure_model()
+        if not self.model.is_configured:
+            with self.fabric.init_module():
+                self.model.configure_model()
 
     def configure_optimization_args(
         self,
diff --git a/energizer/models.py b/energizer/models.py
index 679f623..4291fa8 100644
--- a/energizer/models.py
+++ b/energizer/models.py
@@ -18,6 +18,7 @@ class Model(ABC):
 
     _model_instance: torch.nn.Module | None = None
+    _is_configured: bool = False
 
     @property
     def model_instance(self) -> torch.nn.Module:
@@ -32,6 +33,10 @@ def summary(self) -> str:
 
     def configure_model(self, *args, **kwargs) -> None: ...
 
+    @property
+    def is_configured(self) -> bool:
+        return self._is_configured
+
 
 class TorchModel(Model):
     def __init__(self, model: torch.nn.Module) -> None:
@@ -39,7 +44,7 @@ def __init__(self, model: torch.nn.Module) -> None:
         self._model_instance = model
 
     def configure_model(self, *args, **kwargs) -> None:
-        pass
+        self._is_configured = True
 
 
 class HFModel(Model):
@@ -147,6 +152,8 @@ def configure_model(self) -> None:
         if self._convert_to_bettertransformer:
             self.convert_to_bettertransformer()
 
+        self._is_configured = True
+
     def convert_to_bettertransformer(self) -> None:
         assert self._model_instance is not None
         self._model_instance = self._model_instance.to_bettertransformer()
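
For context, a minimal sketch of the guard this patch introduces: the model wrapper records whether configure_model() has already run, and the caller checks is_configured before configuring again. Model, TorchModel, and is_configured mirror the patch; the DummyEstimator class, the Linear module, and the usage at the bottom are illustrative stand-ins (no Fabric, no real estimator) and are not part of the energizer codebase.

    from abc import ABC, abstractmethod

    import torch


    class Model(ABC):
        # Class-level defaults: a freshly constructed wrapper is not configured yet.
        _model_instance: torch.nn.Module | None = None
        _is_configured: bool = False

        @property
        def is_configured(self) -> bool:
            return self._is_configured

        @abstractmethod
        def configure_model(self, *args, **kwargs) -> None: ...


    class TorchModel(Model):
        def __init__(self, model: torch.nn.Module) -> None:
            super().__init__()
            self._model_instance = model

        def configure_model(self, *args, **kwargs) -> None:
            # Nothing to build for a plain torch module; just record that
            # configuration has happened so callers do not repeat it.
            self._is_configured = True


    class DummyEstimator:
        """Illustrative stand-in for the estimator; not the real class."""

        def __init__(self, model: Model) -> None:
            self.model = model

        def configure_model(self) -> None:
            # The guard from the patch: configure the wrapped model only once.
            if not self.model.is_configured:
                self.model.configure_model()


    estimator = DummyEstimator(TorchModel(torch.nn.Linear(4, 2)))
    estimator.configure_model()
    estimator.configure_model()  # second call is a no-op thanks to the guard
    assert estimator.model.is_configured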