From ccd28fee232f2c8025b53c104a0dee9d82baa539 Mon Sep 17 00:00:00 2001 From: InnopolisU Date: Tue, 27 Aug 2024 10:25:14 +0300 Subject: [PATCH] Fixed yolo pretraine errors --- .../ultralytics/ultralytics_adapter.py | 27 ++++++++++--------- .../base_checkpoint_handler.py | 5 +++- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/innofw/core/integrations/ultralytics/ultralytics_adapter.py b/innofw/core/integrations/ultralytics/ultralytics_adapter.py index f7541ac5..de1d6383 100644 --- a/innofw/core/integrations/ultralytics/ultralytics_adapter.py +++ b/innofw/core/integrations/ultralytics/ultralytics_adapter.py @@ -151,26 +151,27 @@ def train(self, data: UltralyticsDataModuleAdapter, ckpt_path=None): project="train", name=name,) - if ckpt_path is None: - self.opt.update( - device=self.device, - epochs=self.epochs, - imgsz=data.imgsz, - data=data.data, - workers=data.workers, - batch=data.batch_size, - ) - self.model.train(**self.opt, **self.hyp) - else: + if ckpt_path is not None: try: ckpt_path = TorchCheckpointHandler().convert_to_regular_ckpt( - ckpt_path, inplace=False, dst_path=None + ckpt_path, inplace=False, dst_path=None, set_epoch=0 ) self.opt.update(resume=str(ckpt_path)) - self.model.train(**self.opt, **self.hyp) + self.model.ckpt["epoch"] = 0 + self.model.ckpt_path = ckpt_path except Exception as e: print(e) + self.opt.update( + device=self.device, + epochs=self.epochs, + imgsz=data.imgsz, + data=data.data, + workers=data.workers, + batch=data.batch_size, + ) + self.model.train(**self.opt, **self.hyp) + self.update_checkpoints_path() def predict(self, data: UltralyticsDataModuleAdapter, ckpt_path=None): diff --git a/innofw/utils/checkpoint_utils/base_checkpoint_handler.py b/innofw/utils/checkpoint_utils/base_checkpoint_handler.py index 0173791a..6ab8d3b0 100644 --- a/innofw/utils/checkpoint_utils/base_checkpoint_handler.py +++ b/innofw/utils/checkpoint_utils/base_checkpoint_handler.py @@ -85,9 +85,12 @@ def convert_to_regular_ckpt( ckpt_path: Path, dst_path: Optional[Path] = None, inplace: bool = True, + set_epoch: int = -1 ) -> Path: model = self.load_model(None, ckpt_path) - + if set_epoch != -1: + if "epoch" in model.keys(): + model["epoch"] = set_epoch if inplace: tmp_path = Path(tempfile.mkdtemp()) / ckpt_path.name self.save_ckpt(model, tmp_path, None, wrap=False)