Commit

feat: integration lightgbm (#746)
Zeyi-Lin authored Dec 2, 2024
1 parent 016b2d2 commit c02471a
Showing 2 changed files with 122 additions and 0 deletions.
56 changes: 56 additions & 0 deletions swanlab/integration/lightgbm.py
@@ -0,0 +1,56 @@
from typing import TYPE_CHECKING
import lightgbm # type: ignore
from lightgbm import Booster
import swanlab


if TYPE_CHECKING:
    from typing import Any, Dict, List, NamedTuple, Tuple, Union

    # Note: upstream lightgbm has this defined incorrectly
    _EvalResultTuple = Union[
        Tuple[str, str, float, bool], Tuple[str, str, float, bool, float]
    ]

    class CallbackEnv(NamedTuple):
        model: Any
        params: Dict
        iteration: int
        begin_iteration: int
        end_iteration: int
        evaluation_result_list: List[_EvalResultTuple]



class SwanLabCallback:
    def __init__(self, log_params: bool = True) -> None:
        self.order = 20
        self.before_iteration = False
        self.log_params = log_params

    def _init(self, env: "CallbackEnv") -> None:
        # Record the training parameters in the swanlab run config.
        if self.log_params:
            swanlab.config.update(env.params)

    def __call__(self, env: "CallbackEnv") -> None:
        if env.iteration == env.begin_iteration:  # type: ignore
            self._init(env)

        for item in env.evaluation_result_list:
            if len(item) == 4:
                # Standard result: (data_name, eval_name, result, is_higher_better)
                data_name, eval_name, result = item[:3]
                swanlab.log(
                    {data_name + "_" + eval_name: result},
                )
            else:
                # Cross-validation result: carries both a mean and a standard deviation
                data_name, eval_name = item[1].split()
                res_mean = item[2]
                res_stdv = item[4]
                swanlab.log(
                    {
                        data_name + "_" + eval_name + "-mean": res_mean,
                        data_name + "_" + eval_name + "-stdv": res_stdv,
                    },
                )

        swanlab.log({"iteration": env.iteration})
66 changes: 66 additions & 0 deletions test/integration/lightgbm/train_lightgbm.py
@@ -0,0 +1,66 @@
import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import swanlab
from swanlab.integration.lightgbm import SwanLabCallback

# Step 1: Initialize swanlab
swanlab.init(project="lightgbm-example", name="breast-cancer-classification")

# Step 2: Load the dataset
data = load_breast_cancer()
X = data.data
y = data.target

# Step 3: Split the data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Step 4: Create LightGBM datasets
train_data = lgb.Dataset(X_train, label=y_train)
test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)

# Step 5: Set parameters
params = {
    'objective': 'binary',
    'metric': 'binary_logloss',
    'boosting_type': 'gbdt',
    'num_leaves': 31,
    'learning_rate': 0.05,
    'feature_fraction': 0.9
}

# Step 6: Train the model with swanlab callback
num_round = 100
gbm = lgb.train(
    params,
    train_data,
    num_round,
    valid_sets=[test_data],
    callbacks=[SwanLabCallback()]
)

# Step 7: Make predictions
y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)
y_pred_binary = [1 if p >= 0.5 else 0 for p in y_pred]

# Step 8: Evaluate the model
accuracy = accuracy_score(y_test, y_pred_binary)
print(f"Model accuracy: {accuracy:.4f}")
swanlab.log({"accuracy": accuracy})

# Step 9: Save the model locally
gbm.save_model('lightgbm_model.txt')

# Step 10: Load the model and predict again
bst_loaded = lgb.Booster(model_file='lightgbm_model.txt')
y_pred_loaded = bst_loaded.predict(X_test)
y_pred_binary_loaded = [1 if p >= 0.5 else 0 for p in y_pred_loaded]

# Step 11: Evaluate the loaded model
accuracy_loaded = accuracy_score(y_test, y_pred_binary_loaded)
print(f"Accuracy of the loaded model: {accuracy_loaded:.4f}")
swanlab.log({"accuracy_loaded": accuracy_loaded})

# Step 12: Finish the swanlab run
swanlab.finish()
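
One optional variant, sketched here rather than taken from this commit: passing valid_names when training should make the metric keys logged by the callback read as "test_binary_logloss" instead of LightGBM's default "valid_0_binary_logloss".

# Hypothetical variant, not in this commit: name the validation set so the
# callback logs "test_binary_logloss" rather than the default "valid_0_binary_logloss".
gbm = lgb.train(
    params,
    train_data,
    num_round,
    valid_sets=[test_data],
    valid_names=["test"],
    callbacks=[SwanLabCallback()]
)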
