diff --git a/Makefile b/Makefile index 4a7f33b..aa7bac7 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ check_dirs := tfts examples tests # run checks on all files and potentially modifies some of them style: - black --preview $(check_dirs) + black $(check_dirs) isort $(check_dirs) flake8 pre-commit run --all-files diff --git a/requirements.txt b/requirements.txt index e64acd4..37d5bb4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ tensorflow>=2.3.1 -optuna>=2.0 -pandas>=1.0 +optuna>=2.3 +pandas>=1.2 scikit-learn>0.23 joblib matplotlib diff --git a/tfts/trainer.py b/tfts/trainer.py index 993baa4..d884298 100644 --- a/tfts/trainer.py +++ b/tfts/trainer.py @@ -166,7 +166,7 @@ def train_step(self, x_train, y_train): return y_pred, loss def valid_loop(self, valid_loader): - valid_loss = 0.0 + valid_loss: float = 0.0 y_valid_trues, y_valid_preds = [], [] for valid_step, (x_valid, y_valid) in enumerate(valid_loader): @@ -221,18 +221,21 @@ def __init__( optimizer: tf.keras.optimizers = tf.keras.optimizers.Adam(0.003), lr_scheduler: Optional[tf.keras.optimizers.Optimizer] = None, strategy: Optional[tf.keras.optimizers.schedules.LearningRateSchedule] = None, + run_eagerly: bool = True, **kwargs: Dict ) -> None: """ model: tf.keras.Model instance loss: loss function optimizer: tf.keras.Optimizer instance + run_eargely: it depends which one is much faster """ self.model = model self.loss_fn = loss_fn self.optimizer = optimizer self.lr_scheduler = lr_scheduler self.strategy = strategy + self.run_eargely = run_eagerly for key, value in kwargs.items(): setattr(self, key, value) @@ -298,7 +301,9 @@ def train( self.model = self.model.build_model(inputs=inputs) # print(self.model.summary()) - self.model.compile(loss=self.loss_fn, optimizer=self.optimizer, metrics=callback_metrics, run_eagerly=False) + self.model.compile( + loss=self.loss_fn, optimizer=self.optimizer, metrics=callback_metrics, run_eagerly=self.run_eargely + ) if isinstance(train_dataset, (list, tuple)): x_train, y_train = train_dataset diff --git a/tfts/tuner.py b/tfts/tuner.py index 83368b0..f61eb41 100644 --- a/tfts/tuner.py +++ b/tfts/tuner.py @@ -1,5 +1,7 @@ """tfts auto tuner""" +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union + import numpy as np import optuna @@ -15,8 +17,8 @@ class AutoTuner(object): def __init__(self, use_model: str) -> None: self.use_model = use_model - def generate_parameter(self): - pass + def generate_parameter(self) -> None: + return - def run(self, config, direction="maximize"): - pass + def run(self, config, direction: str = "maximize") -> None: + return