Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Vytautas Jancauskas committed Feb 24, 2025
1 parent cebf357 commit 544262b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/methane_super_emitters/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

@click.command()
@click.option("-i", "--input-dir", help="Data directory")
@click.option("-m", "--max-epochs", help="Maximum number of epochs", default=1)
def train_model(input_dir, max_epochs):
@click.option("-m", "--max-epochs", help="Maximum number of epochs", default=100)
@click.option("-n", "--n-trials", help="Number of trials or points to sample", default=200)
def optimize_model(input_dir, max_epochs, n_trials):
def objective(trial):
fields = ["methane", "u10", "v10", "qa"]
dropout_rate = trial.suggest_float("dropout", 0.1, 0.9)
Expand All @@ -18,10 +19,10 @@ def objective(trial):
trainer.fit(model=model, datamodule=datamodule)
return trainer.callback_metrics['val_acc'].item()
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=200)
study.optimize(objective, n_trials=n_trials)
df = study.trials_dataframe()
df.to_csv('opt_results.csv')
print("Best parameters:", study.best_params)

if __name__ == "__main__":
train_model()
optimize_model()
4 changes: 4 additions & 0 deletions test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
import lightning as L
from methane_super_emitters.model import SuperEmitterDetector
from methane_super_emitters.datamodule import TROPOMISuperEmitterDataModule
from methane_super_emitters.optimize import optimize_model

def test_model():
model = SuperEmitterDetector(fields=['methane', 'qa', 'u10', 'v10'])
datamodule = TROPOMISuperEmitterDataModule('./data/dataset', fields=['methane', 'qa', 'u10', 'v10'])
trainer = L.Trainer(max_epochs=10)
trainer.fit(model=model, datamodule=datamodule)
trainer.test(model=model, datamodule=datamodule)

def test_optimize():
optimize_model('data/dataset', 1, 10)

0 comments on commit 544262b

Please sign in to comment.