Skip to content

Commit

Permalink
feat: clean the code (#48)
Browse files Browse the repository at this point in the history
* fix: update trainer and tuner

* chore: update pyproject to support python3.11

* fix: update pyproject

* fix: pyproject
  • Loading branch information
LongxingTan authored Oct 11, 2023
1 parent 6d19553 commit 4b179e8
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ check_dirs := tfts examples tests
# run checks on all files and potentially modifies some of them

style:
black --preview $(check_dirs)
black $(check_dirs)
isort $(check_dirs)
flake8
pre-commit run --all-files
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
tensorflow>=2.3.1
optuna>=2.0
pandas>=1.0
optuna>=2.3
pandas>=1.2
scikit-learn>0.23
joblib
matplotlib
9 changes: 7 additions & 2 deletions tfts/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def train_step(self, x_train, y_train):
return y_pred, loss

def valid_loop(self, valid_loader):
valid_loss = 0.0
valid_loss: float = 0.0
y_valid_trues, y_valid_preds = [], []

for valid_step, (x_valid, y_valid) in enumerate(valid_loader):
Expand Down Expand Up @@ -221,18 +221,21 @@ def __init__(
optimizer: tf.keras.optimizers = tf.keras.optimizers.Adam(0.003),
lr_scheduler: Optional[tf.keras.optimizers.Optimizer] = None,
strategy: Optional[tf.keras.optimizers.schedules.LearningRateSchedule] = None,
run_eagerly: bool = True,
**kwargs: Dict
) -> None:
"""
model: tf.keras.Model instance
loss: loss function
optimizer: tf.keras.Optimizer instance
run_eargely: it depends which one is much faster
"""
self.model = model
self.loss_fn = loss_fn
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.strategy = strategy
self.run_eargely = run_eagerly

for key, value in kwargs.items():
setattr(self, key, value)
Expand Down Expand Up @@ -298,7 +301,9 @@ def train(
self.model = self.model.build_model(inputs=inputs)

# print(self.model.summary())
self.model.compile(loss=self.loss_fn, optimizer=self.optimizer, metrics=callback_metrics, run_eagerly=False)
self.model.compile(
loss=self.loss_fn, optimizer=self.optimizer, metrics=callback_metrics, run_eagerly=self.run_eargely
)
if isinstance(train_dataset, (list, tuple)):
x_train, y_train = train_dataset

Expand Down
10 changes: 6 additions & 4 deletions tfts/tuner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""tfts auto tuner"""

from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union

import numpy as np
import optuna

Expand All @@ -15,8 +17,8 @@ class AutoTuner(object):
def __init__(self, use_model: str) -> None:
self.use_model = use_model

def generate_parameter(self):
pass
def generate_parameter(self) -> None:
return

def run(self, config, direction="maximize"):
pass
def run(self, config, direction: str = "maximize") -> None:
return

0 comments on commit 4b179e8

Please sign in to comment.