diff --git a/gosdt/gosdt-imbalance-tutorial.ipynb b/gosdt/gosdt-imbalance-tutorial.ipynb new file mode 100644 index 0000000..0741578 --- /dev/null +++ b/gosdt/gosdt-imbalance-tutorial.ipynb @@ -0,0 +1,1123 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import time\n", + "import pathlib\n", + "from sklearn.ensemble import GradientBoostingClassifier\n", + "from sklearn.model_selection import train_test_split\n", + "from gosdt.model.threshold_guess import compute_thresholds, cut\n", + "from gosdt.model.gosdt import GOSDT" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Below we show examples of how to run f1, auc convex hull, and partial auc convex hull using GOSDT. Note that optimization for these objectives is slower than optimizing accuracy and balanced accuracy, since dynamic programming cannot be applied.**\n", + "\n", + "\n", + "### Example 1 (Monk 1)\n", + "\n", + "We first show examples using the Monk 1 dataset. " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
head_shape_roundhead_shape_squarehead_shape_octagonbody_shape_roundbody_shape_squarebody_shape_octagonis_smiling_yesis_smiling_noholding_swordholding_balloonholding_flagjacket_color_redjacket_color_yellowjacket_color_greenjacket_color_bluehas_tie_yeshas_tie_no
010010010100001010
110010010100001001
210010010001010010
310010010001001001
410010001100010010
\n", + "
" + ], + "text/plain": [ + " head_shape_round head_shape_square head_shape_octagon body_shape_round \\\n", + "0 1 0 0 1 \n", + "1 1 0 0 1 \n", + "2 1 0 0 1 \n", + "3 1 0 0 1 \n", + "4 1 0 0 1 \n", + "\n", + " body_shape_square body_shape_octagon is_smiling_yes is_smiling_no \\\n", + "0 0 0 1 0 \n", + "1 0 0 1 0 \n", + "2 0 0 1 0 \n", + "3 0 0 1 0 \n", + "4 0 0 0 1 \n", + "\n", + " holding_sword holding_balloon holding_flag jacket_color_red \\\n", + "0 1 0 0 0 \n", + "1 1 0 0 0 \n", + "2 0 0 1 0 \n", + "3 0 0 1 0 \n", + "4 1 0 0 0 \n", + "\n", + " jacket_color_yellow jacket_color_green jacket_color_blue has_tie_yes \\\n", + "0 0 1 0 1 \n", + "1 0 1 0 0 \n", + "2 1 0 0 1 \n", + "3 0 1 0 0 \n", + "4 1 0 0 1 \n", + "\n", + " has_tie_no \n", + "0 0 \n", + "1 1 \n", + "2 0 \n", + "3 1 \n", + "4 0 " + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.preprocessing import OneHotEncoder\n", + "df = pd.read_csv(\n", + " \"https://archive.ics.uci.edu/ml/machine-learning-databases/monks-problems/monks-1.train\", sep=\" \", header=None)\n", + "df = df.iloc[:,1:-1]\n", + "\n", + "# assign column names \n", + "df.columns = [\"label\", \"head_shape\", \"body_shape\", \"is_smiling\", \"holding\", \"jacket_color\", \"has_tie\"]\n", + "X = df.iloc[:,1:]\n", + "y = df.iloc[:,0]\n", + "y = pd.DataFrame(y)\n", + "\n", + "# Encode the categorical features to binary features\n", + "enc = OneHotEncoder()\n", + "enc.fit(X)\n", + "X = pd.DataFrame(enc.transform(X).toarray(), dtype=int)\n", + "X.columns = [\"head_shape_round\", \"head_shape_square\", \"head_shape_octagon\", \n", + " \"body_shape_round\", \"body_shape_square\", \"body_shape_octagon\",\n", + " \"is_smiling_yes\", \"is_smiling_no\", \n", + " \"holding_sword\", \"holding_balloon\", \"holding_flag\",\n", + " \"jacket_color_red\", \"jacket_color_yellow\", \"jacket_color_green\", \"jacket_color_blue\",\n", + " \"has_tie_yes\", \"has_tie_no\"\n", + " ]\n", + 
"X.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Below is an example to set up the configuration. For ``accuracy``, ``balanced accuracy``, ``weighted accuracy``, we suggest using C++ version, since it uses the dynamic programming and can be much faster. \n", + "\n", + "**F1 score**: \n", + "{ \"objective\": \"f1\", \"w\": 0.9, \"theta\": None}. You need to specify hyperparameter $w \\in (0,1)$. This is because optimizing F-score is much harder than other arbitrary monotonic losses.Thus,we simplify the labeling step by incorporating a parameter $w$ at each leaf node. \n", + "\n", + "**AUC convex hull**: { \"objective\": \"auc\", \"w\": None, \"theta\": None } \n", + "\n", + "**Partial AUC convex hull**: { \"objective\": \"pauc\", \"w\": None, \"theta\": 0.1 } You need to specify hyperparameter $\\theta \\in (0,1)$. $\\theta$ is used to specify the left most part of the ROC curve. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### (a) F1 objective" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "config={\n", + " \"regularization\": 0.01,\n", + " \"objective\": \"f1\",\n", + " \"w\": 0.9,\n", + " \"time_limit\": 60\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "model = GOSDT(config)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss function: f1\n", + "lambda: 0.01\n", + "COUNT_UNIQLEAVES: 2185\n", + "COUNT_LEAFLOOKUPS: 80609\n", + "total time: 9.146592378616333\n", + "leaves: [(1,), (-9, -1, 2), (-1, 2, 9), (-13, -3, -2, -1), (-3, -2, -1, 13), (-12, -2, -1, 3), (-2, -1, 3, 12)]\n", + "num_captured: [29, 31, 8, 20, 12, 11, 13]\n", + "prediction: [1, 0, 1, 0, 1, 0, 1]\n", + "Objective: 0.07\n", + "f1 : 1.0\n", + "COUNT of the best tree: 3476\n", + "time 
when the best tree is achieved: 0.23038196563720703\n", + "TOTAL COUNT: 176434\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.fit(X, y)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Decode leaves" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'feature': 11,\n", + " 'name': 'jacket_color_red',\n", + " 'reference': 1,\n", + " 'relation': '==',\n", + " 'true': {'complexity': 0.01, 'loss': 0.0, 'name': 'class', 'prediction': 1},\n", + " 'false': {'feature': 0,\n", + " 'name': 'head_shape_round',\n", + " 'reference': 1,\n", + " 'relation': '==',\n", + " 'true': {'feature': 3,\n", + " 'name': 'body_shape_round',\n", + " 'reference': 1,\n", + " 'relation': '==',\n", + " 'true': {'complexity': 0.01, 'loss': 0.0, 'name': 'class', 'prediction': 1},\n", + " 'false': {'complexity': 0.01,\n", + " 'loss': 0.0,\n", + " 'name': 'class',\n", + " 'prediction': 0}},\n", + " 'false': {'feature': 2,\n", + " 'name': 'head_shape_octagon',\n", + " 'reference': 1,\n", + " 'relation': '==',\n", + " 'true': {'feature': 5,\n", + " 'name': 'body_shape_octagon',\n", + " 'reference': 1,\n", + " 'relation': '==',\n", + " 'true': {'complexity': 0.01,\n", + " 'loss': 0.0,\n", + " 'name': 'class',\n", + " 'prediction': 1},\n", + " 'false': {'complexity': 0.01,\n", + " 'loss': 0.0,\n", + " 'name': 'class',\n", + " 'prediction': 0}},\n", + " 'false': {'feature': 4,\n", + " 'name': 'body_shape_square',\n", + " 'reference': 1,\n", + " 'relation': '==',\n", + " 'true': {'complexity': 0.01,\n", + " 'loss': 0.0,\n", + " 'name': 'class',\n", + " 'prediction': 1},\n", + " 'false': {'complexity': 0.01,\n", + " 'loss': 0.0,\n", + " 'name': 'class',\n", + " 'prediction': 0}}}}}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + 
"source": [ + "model.tree.source" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now explore more about the obtained tree. " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model training time: 9.1525593\n", + "Training accuracy: 1.0\n", + "# of leaves: 7\n" + ] + } + ], + "source": [ + "acc = model.score(X, y) # calculate the accuracy\n", + "n_leaves = model.leaves()\n", + "n_nodes = model.nodes()\n", + "\n", + "print(\"Model training time: {}\".format(model.duration))\n", + "print(\"Training accuracy: {}\".format(acc))\n", + "print(\"# of leaves: {}\".format(n_leaves))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "f1 score: 1.0\n" + ] + } + ], + "source": [ + "# Note that if you want to test f1 metric, tree.score() won't work. \n", + "from sklearn.metrics import f1_score\n", + "y_hat = model.predict(X)\n", + "print(\"f1 score:\", f1_score(y, y_hat))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### (b) AUC convex hull objective" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss function: auc\n", + "lambda: 0.01\n", + "COUNT_UNIQLEAVES: 3087\n", + "COUNT_LEAFLOOKUPS: 99455\n", + "total time: 56.81161284446716\n", + "leaves: [(1,), (-12, -1, 3), (-1, 3, 12), (-14, -9, -3, -1), (-14, -3, -1, 9), (-13, -3, -1, 14), (-3, -1, 13, 14)]\n", + "num_captured: [29, 11, 13, 31, 8, 20, 12]\n", + "prediction: [1, 1, 1, 1, 1, 1, 1]\n", + "Objective: 0.07\n", + "auc : 1.0\n", + "COUNT of the best tree: 52236\n", + "time when the best tree is achieved: 14.399035930633545\n", + "TOTAL COUNT: 247840\n", + "{'feature': 11, 'name': 'jacket_color_red', 'reference': 1, 'relation': '==', 
'true': {'complexity': 0.01, 'loss': 0.0, 'name': 'class', 'prediction': 1}, 'false': {'feature': 2, 'name': 'head_shape_octagon', 'reference': 1, 'relation': '==', 'true': {'feature': 5, 'name': 'body_shape_octagon', 'reference': 1, 'relation': '==', 'true': {'complexity': 0.01, 'loss': 0.0, 'name': 'class', 'prediction': 1}, 'false': {'complexity': 0.01, 'loss': 0.08870967741935482, 'name': 'class', 'prediction': 1}}, 'false': {'feature': 1, 'name': 'head_shape_square', 'reference': 1, 'relation': '==', 'true': {'feature': 4, 'name': 'body_shape_square', 'reference': 1, 'relation': '==', 'true': {'complexity': 0.01, 'loss': 0.0, 'name': 'class', 'prediction': 1}, 'false': {'complexity': 0.01, 'loss': 0.16129032258064507, 'name': 'class', 'prediction': 1}}, 'false': {'feature': 3, 'name': 'body_shape_round', 'reference': 1, 'relation': '==', 'true': {'complexity': 0.01, 'loss': 0.0, 'name': 'class', 'prediction': 1}, 'false': {'complexity': 0.01, 'loss': 0.24999999999999983, 'name': 'class', 'prediction': 1}}}}}\n" + ] + } + ], + "source": [ + "config={\n", + " \"regularization\": 0.01,\n", + " \"objective\": \"auc\",\n", + " \"time_limit\": 60\n", + "}\n", + "model = GOSDT(config)\n", + "model.fit(X, y)\n", + "print(model.tree.source)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The algorithm finishes running within the 60-second time limit. One interesting thing in the output above is that the prediction vector contains all ones. This is because, for the ``auc`` and ``pauc`` objectives, we are optimizing the convex hull. The code below can be used to draw the ROC curve for the ``auc`` and ``pauc`` objectives. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "def helper(node, X, y):\n", + " node[\"pos_neg\"] = [y.sum()[0], y.shape[0]-y.sum()[0], y.sum()[0]/y.shape[0]]\n", + " \n", + " if \"relation\" in node:\n", + " if node[\"relation\"] == \"==\":\n", + " X_left = X[X[node[\"name\"]] == node[\"reference\"]]\n", + " y_left = y[X[node[\"name\"]] == node[\"reference\"]]\n", + " X_right = X[X[node[\"name\"]] != node[\"reference\"]]\n", + " y_right = y[X[node[\"name\"]] != node[\"reference\"]]\n", + " else:\n", + " X_left, y_left, X_right, y_right = None, None, None, None\n", + " return node, X_left, y_left, X_right, y_right\n", + " \n", + "\n", + "def add_prob(tree, X, y):\n", + " pos_neg = []\n", + " nodes = [tree]\n", + " data = [[X, y]]\n", + " while len(nodes) > 0:\n", + " node = nodes.pop()\n", + " data_tmp = data.pop()\n", + " X = data_tmp[0]\n", + " y = data_tmp[1]\n", + " node, X_left, y_left, X_right, y_right = helper(node, X, y)\n", + " if \"prediction\" not in node:\n", + " nodes.append(node[\"true\"])\n", + " data.append([X_left, y_left])\n", + " nodes.append(node[\"false\"])\n", + " data.append([X_right, y_right])\n", + " else:\n", + " pos_neg.append(node[\"pos_neg\"])\n", + " return pos_neg\n", + "\n", + "def plot_roc(tree, X, y):\n", + " # plot roc curve for auc and pauc objective\n", + " pos_neg = add_prob(tree, X, y)\n", + " pos_neg = np.array(pos_neg).T\n", + " \n", + " P = np.count_nonzero(y) # positive samples\n", + " N = len(y) - P\n", + " \n", + " pos_neg = pos_neg[:,np.argsort(pos_neg[2,])]\n", + " pos_neg = np.flip(pos_neg, axis=1)\n", + " pos_neg = np.cumsum(pos_neg,axis=1)\n", + " init = np.array([[0], [0], [0]])\n", + " pos_neg = np.append(init, pos_neg, axis=1) \n", + " tp = pos_neg[0, :]\n", + " fp = pos_neg[1, :]\n", + " out = 0.5*sum([(pos_neg[0,i]+pos_neg[0,i-1])*(pos_neg[1,i]-pos_neg[1,i-1])/(P*N) for i in range(1,len(tp))])\n", + " print(\"area under roc curve:\", out)\n", + 
" plt.plot([f/N for f in fp], [f/P for f in tp], 'go-', linewidth=2)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "area under roc curve: 1.0\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAQi0lEQVR4nO3dYWhd533H8e8/sbMilqpjVmHElpQyB2riQYrIMgprirvhBGy/6YqNQtcRKtotHaFlkKGRtil6sZatpsxbe9lK16ImTfuilYNLYF5CR6m7KKSNawcPz7UdkdKoXaY3oo3D/ntxb8SVLPke2Vf36jz6fkD4PM95fM//8bn66ficc3UiM5Ek1d9N/S5AktQdBrokFcJAl6RCGOiSVAgDXZIKsa1fG96xY0eOjo72a/OSVEvPP//8LzJzaLV1fQv00dFRZmdn+7V5SaqliLi01jpPuUhSIQx0SSqEgS5JhTDQJakQBrokFaJjoEfElyPi1Yj4yRrrIyK+EBHnI+LFiHhX98tsmj49zejRUW769E2MHh1l+vT0Rm2qp0qdV924H7TRNvo9VuW2xa8A/wB8dY319wG7W1+/D/xT68+umj49zcTxCRavLAJwaeESE8cnABjfO97tzfVMqfOqG/eDNlov3mNR5dfnRsQo8FRm3rnKui8Bz2bm4632OeDezPzZtV5zbGws13Mf+ujRUS4trHn7pSTV0sjgCBcfvlh5fEQ8n5ljq63rxjn024CX29pzrb7VCpmIiNmImJ2fn1/XRi4vXL7+CiVpk+pmtnXjk6KxSt+qh/2Z2QAa0DxCX89GhgeHVz1CX+9Pt81mrf951H1edeN+0EZb6z02PDjctW104wh9DtjV1t4JvNKF111mat8UA9sHlvUNbB9gat9UtzfVU6XOq27cD9povXiPdSPQZ4APtu52uQdY6HT+/HqM7x2ncaCx1B4ZHKFxoFH7C1ZvzmtkcIQgiplX3bgftNF68R7reFE0Ih4H7gV2AD8HPglsB8jML0ZE0LwLZj+wCPxZZna82rnei6JL9Xy6eYYnP+mzUCVtPde6KNrxHHpmHumwPoG/uM7aJEld4idFJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRC1CvTp09NLy6NHR5e1JWmrq02gT5+eZuL4xFL70sIlJo5PGOqS1FKbQJ88OcnilcVlfYtXFpk8OdmniiRpc6lNoF9euLyufknaamoT6MODw+vql6StpjaBPrVvioHtA8v6BrYPMLVvqk8VSdLmUptAH987TuNAY6k9MjhC40CD8b3jfaxKkjaP2gQ6sCy8Lz580TCXpDa1CnRJ0toMdEkqhIEuSYUw0CWpEAa6JBXCQJekQlQK9IjYHxHnIuJ8RDyyyvrhiHgmIl6IiBcj4v7ulypJupaOgR4RNwPHgPuAPcCRiNizYtjfAE9m5l3AYeAfu12oJOnaqhyh3w2cz8wLmfk68ARwaMWYBN7aWh4EXuleiZKkKqoE+m3Ay23tuVZfu08BD0TEHHAC+NhqLxQRExExGxGz8/Pz11GuJGktVQI9VunLFe0jwFcycydwP/C1iLjqtTOzkZljmTk2NDS0/molSWuqEuhzwK
629k6uPqXyIPAkQGb+AHgLsKMbBUqSqqkS6M8BuyPi9oi4heZFz5kVYy4D+wAi4p00A91zKpLUQx0DPTPfAB4CngZeonk3y5mIeCwiDraGfQL4cET8GHgc+FBmrjwtI0naQNuqDMrMEzQvdrb3Pdq2fBZ4d3dLkySth58UlaRCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRC1CvTp09NLy6NHR5e1JWmrq02gT5+eZuL4xFL70sIlJo5PGOqS1FKbQJ88OcnilcVlfYtXFpk8OdmniiRpc6lNoF9euLyufknaamoT6MODw+vql6StpjaBPrVvioHtA8v6BrYPMLVvqk8VSdLmUptAH987TuNAY6k9MjhC40CD8b3jfaxKkjaP2gQ6sCy8Lz580TCXpDa1CnRJ0toMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklSISoEeEfsj4lxEnI+IR9YY84GIOBsRZyLi690tU5LUybZOAyLiZuAY8EfAHPBcRMxk5tm2MbuBvwbenZmvRcTbN6pgSdLqqhyh3w2cz8wLmfk68ARwaMWYDwPHMvM1gMx8tbtlSpI6qRLotwEvt7XnWn3t7gDuiIjvR8SpiNi/2gtFxEREzEbE7Pz8/PVVLElaVZVAj1X6ckV7G7AbuBc4AvxzRLztqr+U2cjMscwcGxoaWm+tkqRrqBLoc8CutvZO4JVVxnwnM69k5k+BczQDXpLUI1UC/Tlgd0TcHhG3AIeBmRVjvg28FyAidtA8BXOhm4VKkq6tY6Bn5hvAQ8DTwEvAk5l5JiIei4iDrWFPA7+MiLPAM8BfZeYvN6poSdLVOt62CJCZJ4ATK/oebVtO4OOtL0lSH/hJUUkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSpErQJ9+vT00vLo0dFlbUna6moT6NOnp5k4PrHUvrRwiYnjE4a6JLXUJtAnT06yeGVxWd/ilUUmT072qSJJ2lxqE+iXFy6vq1+StpraBPrw4PC6+iVpq6lNoE/tm2Jg+8CyvoHtA0ztm+pTRZK0udQm0Mf3jtM40FhqjwyO0DjQYHzveB+rkqTNozaBDiwL74sPXzTMJalNrQJdkrQ2A12SCmGgS1IhDHRJKoSBLkmFMNAlqRCVAj0i9kfEuYg4HxGPXGPc+yMiI2KseyVKkqroGOgRcTNwDLgP2AMciYg9q4y7FfhL4IfdLlKS1FmVI/S7gfOZeSEzXweeAA6tMu4zwGeBX3WxPklSRVUC/Tbg5bb2XKtvSUTcBezKzKeu9UIRMRERsxExOz8/v+5iJUlrqxLosUpfLq2MuAn4PPCJTi+UmY3MHMvMsaGhoepVSpI6qhLoc8CutvZO4JW29q3AncCzEXERuAeY8cKoJPVWlUB/DtgdEbdHxC3AYWDmzZWZuZCZOzJzNDNHgVPAwcyc3ZCKJUmr6hjomfkG8BDwNPAS8GRmnomIxyLi4EYXKEmqZluVQZl5Ajixou/RNcbee+NlSZLWy0+KSlIhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEJUCvSI2B8R5yLifEQ8ssr6j0fE2Yh4MSJORsRI90uVJF1Lx0CPiJuBY8B9wB7gSETsWTHsBWAsM38P+Bbw2W4XKkm6tipH6HcD5zPzQma+DjwBHGofkJnPZOZiq3kK2NndMiVJnVQJ9NuAl9vac62+tTwIfHe1FRExERGzETE7Pz9fvUpJUkdVAj1W6c
tVB0Y8AIwBn1ttfWY2MnMsM8eGhoaqVylJ6mhbhTFzwK629k7glZWDIuJ9wCTwnsz8dXfKkyRVVeUI/Tlgd0TcHhG3AIeBmfYBEXEX8CXgYGa+2v0yJUmddAz0zHwDeAh4GngJeDIzz0TEYxFxsDXsc8BvAt+MiB9FxMwaLydJ2iBVTrmQmSeAEyv6Hm1bfl+X65IkrZOfFJWkQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRCVAj0i9kfEuYg4HxGPrLL+NyLiG631P4yI0W4XCjB9enppefTo6LK2JG11HQM9Im4GjgH3AXuAIxGxZ8WwB4HXMvN3gc8Df9vtQqdPTzNxfGKpfWnhEhPHJwx1SWqpcoR+N3A+My9k5uvAE8ChFWMOAf/aWv4WsC8iontlwuTJSRavLC7rW7yyyOTJyW5uRpJqq0qg3wa83Naea/WtOiYz3wAWgN9e+UIRMRERsxExOz8/v65CLy9cXle/JG01VQJ9tSPtvI4xZGYjM8cyc2xoaKhKfUuGB4fX1S9JW02VQJ8DdrW1dwKvrDUmIrYBg8D/dKPAN03tm2Jg+8CyvoHtA0ztm+rmZiSptqoE+nPA7oi4PSJuAQ4DMyvGzAB/2lp+P/DvmXnVEfqNGN87TuNAg5HBEYJgZHCExoEG43vHu7kZSaqtbZ0GZOYbEfEQ8DRwM/DlzDwTEY8Bs5k5A/wL8LWIOE/zyPzwRhQ7vnfcAJekNXQMdIDMPAGcWNH3aNvyr4A/6W5pkqT18JOiklQIA12SCmGgS1IhDHRJKkR0+e7C6huOmAcuXedf3wH8oovl1IFz3hqc89ZwI3MeycxVP5nZt0C/ERExm5lj/a6jl5zz1uCct4aNmrOnXCSpEAa6JBWiroHe6HcBfeCctwbnvDVsyJxreQ5dknS1uh6hS5JWMNAlqRCbOtA3y8Ope6nCnD8eEWcj4sWIOBkRI/2os5s6zblt3PsjIiOi9re4VZlzRHygta/PRMTXe11jt1V4bw9HxDMR8ULr/X1/P+rsloj4ckS8GhE/WWN9RMQXWv8eL0bEu254o5m5Kb9o/qre/wbeAdwC/BjYs2LMnwNfbC0fBr7R77p7MOf3AgOt5Y9uhTm3xt0KfA84BYz1u+4e7OfdwAvAb7Xab+933T2YcwP4aGt5D3Cx33Xf4Jz/EHgX8JM11t8PfJfmE9/uAX54o9vczEfom+Lh1D3Wcc6Z+Uxmvvm07FM0nyBVZ1X2M8BngM8Cv+plcRukypw/DBzLzNcAMvPVHtfYbVXmnMBbW8uDXP1ktFrJzO9x7Se3HQK+mk2ngLdFxO/cyDY3c6B37eHUNVJlzu0epPkTvs46zjki7gJ2ZeZTvSxsA1XZz3cAd0TE9yPiVETs71l1G6PKnD8FPBARczSfv/Cx3pTWN+v9fu+o0gMu+qRrD6eukcrziYgHgDHgPRta0ca75pwj4ibg88CHelVQD1TZz9tonna5l+b/wv4jIu7MzP/d4No2SpU5HwG+kpl/FxF/QPMpaHdm5v9tfHl90fX82sxH6Jvi4dQ9VmXORMT7gEngYGb+uke1bZROc74VuBN4NiIu0jzXOFPzC6NV39vfycwrmflT4BzNgK+rKnN+EHgSIDN/ALyF5i+xKlWl7/f12MyBvikeTt1jHefcOv3wJZphXvfzqtBhzpm5kJk7MnM0M0dpXjc4mJmz/Sm3K6q8t79N8wI4EbGD5imYCz2tsruqzPkysA8gIt5JM9Dne1plb80AH2zd7XIPsJCZP7uhV+z3leAOV4nvB/6L5tXxyVbfYzS/oaG5w78JnAf+E3hHv2vuwZz/Dfg58KPW10y/a97oOa8Y+y
w1v8ul4n4O4O+Bs8Bp4HC/a+7BnPcA36d5B8yPgD/ud803ON/HgZ8BV2gejT8IfAT4SNs+Ptb69zjdjfe1H/2XpEJs5lMukqR1MNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIf4fBNAci7XdwoEAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import copy\n", + "from matplotlib import pyplot as plt\n", + "a = copy.deepcopy(model.tree.source) \n", + "plot_roc(a, X, y)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### (c) Partial AUC convex hull objective" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss function: pauc\n", + "lambda: 0.01\n", + "COUNT_UNIQLEAVES: 4968\n", + "COUNT_LEAFLOOKUPS: 159398\n", + "total time: 60.001120805740356\n", + "leaves: [(1,), (-9, -1, 2), (-1, 2, 9), (-14, -12, -2, -1), (-14, -2, -1, 12), (-13, -2, -1, 14), (-2, -1, 13, 14)]\n", + "num_captured: [29, 31, 8, 11, 13, 20, 12]\n", + "prediction: [1, 1, 1, 1, 1, 1, 1]\n", + "Objective: 0.8700000000000001\n", + "pauc : 0.19999999999999996\n", + "COUNT of the best tree: 154031\n", + "time when the best tree is achieved: 40.786935567855835\n", + "TOTAL COUNT: 240889\n", + "{'feature': 11, 'name': 'jacket_color_red', 'reference': 1, 'relation': '==', 'true': {'complexity': 0.01, 'loss': 0.0, 'name': 'class', 'prediction': 1}, 'false': {'feature': 0, 'name': 'head_shape_round', 'reference': 1, 'relation': '==', 'true': {'feature': 3, 'name': 'body_shape_round', 'reference': 1, 'relation': '==', 'true': {'complexity': 0.01, 'loss': 0.0, 'name': 'class', 'prediction': 1}, 'false': {'complexity': 0.01, 'loss': 0.24999999999999983, 'name': 'class', 'prediction': 1}}, 'false': {'feature': 1, 'name': 'head_shape_square', 'reference': 1, 'relation': '==', 'true': {'feature': 4, 'name': 'body_shape_square', 'reference': 1, 'relation': '==', 'true': {'complexity': 0.01, 'loss': 0.0, 'name': 'class', 'prediction': 1}, 'false': {'complexity': 0.01, 'loss': 0.16129032258064507, 'name': 'class', 'prediction': 1}}, 'false': {'feature': 5, 'name': 'body_shape_octagon', 'reference': 1, 
'relation': '==', 'true': {'complexity': 0.01, 'loss': 0.0, 'name': 'class', 'prediction': 1}, 'false': {'complexity': 0.01, 'loss': 0.08870967741935482, 'name': 'class', 'prediction': 1}}}}}\n" + ] + } + ], + "source": [ + "config={\n", + " \"regularization\": 0.01,\n", + " \"objective\": \"pauc\",\n", + " \"theta\": 0.2,\n", + " \"time_limit\": 60\n", + "}\n", + "model = GOSDT(config)\n", + "model.fit(X, y)\n", + "print(model.tree.source)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The algorithm cannot finish running within the 60-second time limit and returns the best tree found so far. We can also draw the ROC curve. You may find that the area reported below is different from the pauc convex hull value in the output above. This is because here we calculate the whole area under the ROC curve, whereas the previous value is only the area for the leftmost theta proportion of the ROC curve. " + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "area under roc curve: 1.0\n" + ] + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAQcklEQVR4nO3dYWhd533H8e8/sbMilmpj1mDElpQyB+rFgxSRZRTWFJXhBGy/6YqNQtcRKtotHaFlkKGRtil6sZatpsxbK7bStahJ0zJau7gE5iV0lLqLQtq4TvDwXNsRKY3adXoj2jjsvxf3RrtSrnzPjY/u1X30/YDweZ7z6Jz/43P18/E59+pEZiJJGnw39LsASVI9DHRJKoSBLkmFMNAlqRAGuiQVYke/drxr164cHx/v1+4laSA988wzP83MkXbr+hbo4+PjLCws9Gv3kjSQIuLyRuu85CJJhTDQJakQBrokFcJAl6RCGOiSVIiOgR4Rn4+IlyPihxusj4j4TERciIjnIuJt9ZfZMH92nvFj49zw8RsYPzbO/Nn5zdqVtiFfX1tHqcdis+dV5W2LXwD+DvjiBuvvAfY2v34P+Ifmn7WaPzvP9MlpVq6uAHB5+TLTJ6cBmNo/VffutM34+to6Sj0WvZhXVPn1uRExDnwzM29vs+5zwFOZ+WizfR64OzN/fK1tTkxMZDfvQx8/Ns7l5Q3ffilJA2lseIxLD16qPD4insnMiXbr6riGfgvwYkt7sdnXrpDpiFiIiIWlpaWudnJl+cobr1CStqg6s62OT4pGm762p/2ZOQfMQeMMvZudjA6Ptj1D7/ZfN6mdjf4H6Our90o9FhvNa3R4tLZ91HGGvgjsaWnvBl6qYbtrzE7OMrRzaE3f0M4hZidn696VtiFfX1tHqceiF/OqI9BPAO9tvtvlLmC50/XzN2Jq/xRzB+dW22PDY8wdnBvomyTaOl57fY0NjxGEr68+KvVY9GJeHW+KRsSjwN3ALuAnwEeBnQCZ+dmICBrvgjkArAB/kpkd73Z2e1N0tZ6PN67w5Ed9Fqqk7edaN0U7XkPPzKMd1ifwZ2+wNklSTfykqCQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhagU6BFxICLOR8SFiHiozfrRiHgyIp6NiOci4t76S5UkXUvHQI+IG4HjwD3APuBoROxbN+yvgMcz8w7gCPD3dRcqSbq2KmfodwIXMvNiZr4CPAYcXjcmgTc3l4eBl+or8f/Nn51fXR4/Nr6mLUnbXZVAvwV4saW92Oxr9THgvohYBE4BH2q3oYiYjoiFiFhYWlrqqtD5s/NMn5xebV9evsz0yWlDXZKaqgR6tOnLde2jwBcyczdwL/CliHjdtjNzLjMnMnNiZGSkq0JnTs+wcnVlTd/K1RVmTs90tR1JKlWVQF8E9rS0d/P6Syr3A48DZOZ3gTcBu+oo8DVXlq901S9J202VQH8a2BsRt0bETTRuep5YN+YKMAkQEW+lEejdXVPpYHR4tKt+SdpuOgZ6Zr4KPAA8AbxA490s5yLikYg41Bz2EeD9EfED4FHgfZm5/rLMdZmdnGVo59CavqGdQ8xOzta5G0kaWDuqDMrMUzRudrb2Pdyy/Dzw9npLW2tq/xQA9/3LfQCMDY8xOzm72i9J291AfVK0NbwvPXjJMJekFgMV6JKkjRnoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBX
CQJekQhjoklQIA12SCmGgS1IhDHRJKkSlQI+IAxFxPiIuRMRDG4x5T0Q8HxHnIuLL9ZYpSeqkY6BHxI3AceAeYB9wNCL2rRuzF/hL4O2Z+TvAg5tQK/Nn51eXx4+Nr2lL0nZX5Qz9TuBCZl7MzFeAx4DD68a8HziemT8HyMyX6y2zEebTJ6dX25eXLzN9ctpQl6SmKoF+C/BiS3ux2dfqNuC2iPhORJyJiAPtNhQR0xGxEBELS0tLXRU6c3qGlasra/pWrq4wc3qmq+1IUqmqBHq06ct17R3AXuBu4CjwjxHxa6/7psy5zJzIzImRkZGuCr2yfKWrfknabqoE+iKwp6W9G3ipzZhvZObVzPwRcJ5GwNdmdHi0q35J2m6qBPrTwN6IuDUibgKOACfWjfk68E6AiNhF4xLMxToLnZ2cZWjn0Jq+oZ1DzE7O1rkbSRpYHQM9M18FHgCeAF4AHs/McxHxSEQcag57AvhZRDwPPAn8RWb+rM5Cp/ZPMXdwbrU9NjzG3ME5pvZP1bkbSRpYkbn+cnhvTExM5MLCQtffFx9vXNLPj/anbknqp4h4JjMn2q3zk6KSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUiIEK9Pmz86vL48fG17QlabsbmECfPzvP9Mnp1fbl5ctMn5w21CWpaWACfeb0DCtXV9b0rVxdYeb0TJ8qkqStZWAC/cryla76JWm7GZhAHx0e7apfkrabgQn02clZhnYOrekb2jnE7ORsnyqSpK1lYAJ9av8UcwfnVttjw2PMHZxjav9UH6uSpK1jYAIdWBPelx68ZJhLUouBCnRJ0sYMdEkqhIEuSYUw0CWpEAa6JBXCQJekQlQK9Ig4EBHnI+JCRDx0jXHvjoiMiIn6SpQkVdEx0CPiRuA4cA+wDzgaEfvajLsZ+HPge3UXKUnqrMoZ+p3Ahcy8mJmvAI8Bh9uM+wTwSeAXNdYnSaqoSqDfArzY0l5s9q2KiDuAPZn5zWttKCKmI2IhIhaWlpa6LlaStLEqgR5t+nJ1ZcQNwKeBj3TaUGbOZeZEZk6MjIxUr1KS1FGVQF8E9rS0dwMvtbRvBm4HnoqIS8BdwAlvjEpSb1UJ9KeBvRFxa0TcBBwBTry2MjOXM3NXZo5n5jhwBjiUmQubUrEkqa2OgZ6ZrwIPAE8ALwCPZ+a5iHgkIg5tdoGSpGp2VBmUmaeAU+v6Ht5g7N3XX5YkqVt+UlSSCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVolKgR8SBiDgfERci4qE26z8cEc9HxHMRcToixuovVZJ0LR0DPSJuBI4D9wD7gKMRsW/dsGeBicz8XeBrwCfrLlSSdG1VztDvBC5k5sXMfAV4DDjcOiAzn8zMlWbzDLC73jIlSZ1UCfRbgBdb2ovNvo3cD3yr3YqImI6IhYhYWFpaql6lJKmjKoEebfqy7cCI+4AJ4FPt1mfmXGZOZObEyMhI9SolSR3tqDBmEdjT0t4NvLR+UES8C5gB3pGZv6ynPElSVVXO0J8G9kbErRFxE3AEONE6ICLuAD4HHMrMl+svU5LUScdAz8xXgQeAJ4AXgMcz81xEPBIRh5rDPgX8KvDViPh+RJzYYHOSpE1S5ZILmXkKOLWu7+GW5XfVXJckqUt+UlSSCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBX
CQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEJUCvSIOBAR5yPiQkQ81Gb9r0TEV5rrvxcR43UXCjB/dn51efzY+Jq2JG13HQM9Im4EjgP3APuAoxGxb92w+4GfZ+ZvA58G/rruQufPzjN9cnq1fXn5MtMnpw11SWqqcoZ+J3AhMy9m5ivAY8DhdWMOA//cXP4aMBkRUV+ZMHN6hpWrK2v6Vq6uMHN6ps7dSNLAqhLotwAvtrQXm31tx2Tmq8Ay8BvrNxQR0xGxEBELS0tLXRV6ZflKV/2StN1UCfR2Z9r5BsaQmXOZOZGZEyMjI1XqWzU6PNpVvyRtN1UCfRHY09LeDby00ZiI2AEMA/9dR4GvmZ2cZWjn0Jq+oZ1DzE7O1rkbSRpYVQL9aWBvRNwaETcBR4AT68acAP64ufxu4N8y83Vn6Ndjav8UcwfnGBseIwjGhseYOzjH1P6pOncjSQNrR6cBmflqRDwAPAHcCHw+M89FxCPAQmaeAP4J+FJEXKBxZn5kM4qd2j9lgEvSBjoGOkBmngJOret7uGX5F8Af1VuaJKkbflJUkgphoEtSIQx0SSqEgS5JhYia311YfccRS8DlN/jtu4Cf1ljOIHDO24Nz3h6uZ85jmdn2k5l9C/TrERELmTnR7zp6yTlvD855e9isOXvJRZIKYaBLUiEGNdDn+l1AHzjn7cE5bw+bMueBvIYuSXq9QT1DlyStY6BLUiG2dKBvlYdT91KFOX84Ip6PiOci4nREjPWjzjp1mnPLuHdHREbEwL/FrcqcI+I9zWN9LiK+3Osa61bhtT0aEU9GxLPN1/e9/aizLhHx+Yh4OSJ+uMH6iIjPNP8+nouIt133TjNzS37R+FW9/wW8BbgJ+AGwb92YPwU+21w+Anyl33X3YM7vBIaayx/cDnNujrsZ+DZwBpjod909OM57gWeBX2+2f7PfdfdgznPAB5vL+4BL/a77Ouf8B8DbgB9usP5e4Fs0nvh2F/C9693nVj5D3xIPp+6xjnPOzCcz87WnZZ+h8QSpQVblOAN8Avgk8IteFrdJqsz5/cDxzPw5QGa+3OMa61Zlzgm8ubk8zOufjDZQMvPbXPvJbYeBL2bDGeDXIuK3rmefWznQa3s49QCpMudW99P4F36QdZxzRNwB7MnMb/aysE1U5TjfBtwWEd+JiDMRcaBn1W2OKnP+GHBfRCzSeP7Ch3pTWt90+/PeUaUHXPRJbQ+nHiCV5xMR9wETwDs2taLNd805R8QNwKeB9/WqoB6ocpx30LjscjeN/4X9e0Tcnpn/s8m1bZYqcz4KfCEz/yYifp/GU9Buz8z/3fzy+qL2/NrKZ+hb4uHUPVZlzkTEu4AZ4FBm/rJHtW2WTnO+GbgdeCoiLtG41nhiwG+MVn1tfyMzr2bmj4DzNAJ+UFWZ8/3A4wCZ+V3gTTR+iVWpKv28d2MrB/qWeDh1j3Wcc/Pyw+dohPmgX1eFDnPOzOXM3JWZ45k5TuO+waHMXOhPubWo8tr+Oo0b4ETELhqXYC72tMp6VZnzFWASICLeSiPQl3paZW+dAN7bfLfLXcByZv74urbY7zvBHe4S3wv8J4274zPNvkdo/EBD44B/FbgA/Afwln7X3IM5/yvwE+D7za8T/a55s+e8buxTDPi7XCoe5wD+FngeOAsc6XfNPZjzPuA7NN4B833gD/td83XO91Hgx8BVGmfj9wMfAD7QcoyPN/8+ztbxuvaj/5JUiK18yUWS1AUDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXi/wBu6hOKwQjr1AAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "a = copy.deepcopy(model.tree.source) \n", + "plot_roc(a, X, y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example on running on FourClass dataset\n", + "Monk1 is not very interesting. Different objectives end up with the same ROC curve. In the second example, we run on Four Class dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.read_csv(\n", + " \"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/fourclass\", sep=\"\\s+|:\", header=None, engine=\"python\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "X = df.iloc[:,[2,4]].astype(int)\n", + "X.columns = [\"f1\", \"f2\"]\n", + "y = df.iloc[:,0]\n", + "y = pd.DataFrame(y)\n", + "y[y==-1] = 0" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
f1f2
0167178
116231
214296
3132169
416974
\n", + "
" + ], + "text/plain": [ + " f1 f2\n", + "0 167 178\n", + "1 162 31\n", + "2 142 96\n", + "3 132 169\n", + "4 169 74" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "These two features are continuous. We use threshold guess to preprocess the data. " + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(862, 7)\n", + " f1<=70.5 f1<=73.5 f1<=114.5 f1<=120.5 f1<=143.5 f2<=82.5 f2<=167.5\n", + "0 0 0 0 0 0 0 0\n", + "1 0 0 0 0 0 1 1\n", + "2 0 0 0 0 1 0 1\n", + "3 0 0 0 0 1 0 0\n", + "4 0 0 0 0 0 1 1\n" + ] + } + ], + "source": [ + "# GBDT parameters for threshold and lower bound guesses\n", + "n_est = 40\n", + "max_depth = 1\n", + "# guess thresholds\n", + "X_guessed, thresholds, header, threshold_guess_time = compute_thresholds(X.copy(), y, n_est, max_depth)\n", + "print(X_guessed.shape)\n", + "print(X_guessed.head())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### (a) F1 objective" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "config={\n", + " \"regularization\": 0.01,\n", + " \"objective\": \"f1\",\n", + " \"w\": 0.9,\n", + " \"time_limit\": 60\n", + "}\n", + "model = GOSDT(config)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss function: f1\n", + "lambda: 0.01\n", + "COUNT_UNIQLEAVES: 80\n", + "COUNT_LEAFLOOKUPS: 14\n", + "total time: 0.24343013763427734\n", + "leaves: [(1,), (-1, 6), (-7, -6, -1), (-6, -4, -1, 7), (-6, -1, 4, 7)]\n", + "num_captured: [294, 230, 65, 138, 135]\n", + "prediction: [0, 1, 0, 1, 0]\n", + "Objective: 0.29444444444444445\n", + "f1 : 0.7555555555555555\n", + "COUNT of the 
best tree: 116\n", + "time when the best tree is achieved: 0.2224864959716797\n", + "TOTAL COUNT: 133\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.fit(X_guessed, y)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model training time: 0.24345089999999914\n", + "Training accuracy: 0.808584686774942\n", + "# of leaves: 5\n" + ] + } + ], + "source": [ + "acc = model.score(X_guessed, y)\n", + "n_leaves = model.leaves()\n", + "n_nodes = model.nodes()\n", + "\n", + "print(\"Model training time: {}\".format(model.duration))\n", + "print(\"Training accuracy: {}\".format(acc))\n", + "print(\"# of leaves: {}\".format(n_leaves))" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "f1 score: 0.7555555555555555\n" + ] + } + ], + "source": [ + "# Note: model.score() reports accuracy, not F1, so compute the F1 metric with sklearn's f1_score instead. \n", + "from sklearn.metrics import f1_score\n", + "y_hat = model.predict(X_guessed)\n", + "print(\"f1 score:\", f1_score(y, y_hat)) # the value matches the printout above. 
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### (b) AUC convex hull objective" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss function: auc\n", + "lambda: 0.01\n", + "COUNT_UNIQLEAVES: 655\n", + "COUNT_LEAFLOOKUPS: 112127\n", + "total time: 60.00080442428589\n", + "leaves: [(2,), (-4, -2, 6), (-6, -2, 4), (-2, 4, 6), (-7, -6, -4, -2), (-6, -4, -2, 7)]\n", + "num_captured: [278, 114, 153, 123, 56, 138]\n", + "prediction: [1, 1, 1, 1, 1, 1]\n", + "Objective: 0.16208645127211901\n", + "auc : 0.897913548727881\n", + "COUNT of the best tree: 4523\n", + "time when the best tree is achieved: 3.0249080657958984\n", + "TOTAL COUNT: 95088\n", + "{'feature': 0, 'name': 'f1<=70.5', 'reference': 1, 'relation': '==', 'true': {'complexity': 0.01, 'loss': 0.31322505800464034, 'name': 'class', 'prediction': 1}, 'false': {'feature': 3, 'name': 'f1<=120.5', 'reference': 1, 'relation': '==', 'true': {'feature': 5, 'name': 'f2<=82.5', 'reference': 1, 'relation': '==', 'true': {'complexity': 0.01, 'loss': 0.07540603248259867, 'name': 'class', 'prediction': 1}, 'false': {'complexity': 0.01, 'loss': 0.1299303944315549, 'name': 'class', 'prediction': 1}}, 'false': {'feature': 5, 'name': 'f2<=82.5', 'reference': 1, 'relation': '==', 'true': {'complexity': 0.01, 'loss': 0.002320185614849188, 'name': 'class', 'prediction': 1}, 'false': {'feature': 6, 'name': 'f2<=167.5', 'reference': 1, 'relation': '==', 'true': {'complexity': 0.01, 'loss': 0.058004640371229696, 'name': 'class', 'prediction': 1}, 'false': {'complexity': 0.01, 'loss': 0.06496519721577726, 'name': 'class', 'prediction': 1}}}}}\n" + ] + } + ], + "source": [ + "config={\n", + " \"regularization\": 0.01,\n", + " \"objective\": \"auc\",\n", + " \"time_limit\": 60\n", + "}\n", + "model = GOSDT(config)\n", + "model.fit(X_guessed, y)\n", + "print(model.tree.source)" + ] + }, + { + 
"cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "area under roc curve: 0.897913548727881\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAcF0lEQVR4nO3de3RU9b338fc3Fy5RLkWiPQokVtEDBRWJQEVrXdiCWqCe4wUatX1EI1ZBavu4sHF5+qj01PZYbFqqpD2eWk25WG/RUq1FPaLFQJAqgtKTUgkRkaAWLwGSkO/zxyQ5IZmQCczMzuz5vNZiMbNnM/vzY8LH7W/vPdvcHRERSX0ZQQcQEZH4UKGLiISECl1EJCRU6CIiIaFCFxEJiaygNjx48GDPz88PavMiIilp3bp1u9w9N9prgRV6fn4+lZWVQW1eRCQlmdnWzl7TlIuISEio0EVEQkKFLiISEip0EZGQUKGLiIREl4VuZveb2U4ze6OT183MSsysysxeN7PT4x9TRCQ+yjaUkX9PPhn/L4P8e/Ip21AWmm3Hsof+a2DKQV4/Hxje/KsIuPfwY4mIxF/ZhjKKnixi6+6tOM7W3VsperIoKaWejG1bLF+fa2b5wFPuPirKa4uBF9x9SfPzzcCX3P3dg71nQUGB6zx0EelKY1Mjexr2sKdxT4ff9zbu7fS1tr/vbdzLnsY9PPrmo+xp3NNhG70zezPuuHEJHcead9awb/++DsvzBuTx9ry3Y34fM1vn7gXRXovHhUXHAdvaPK9pXtah0M2siMhePMOGDYvDpkUkmZq8KWpRdlWqUYu3i/JtWdbY1Jjwce3bv49V1asSvp1oqndXx+294lHoFmVZ1N1+dy8FSiGyhx6HbYukLXdn3/593dpD7XSdGPd66/fXJ32cGZZB36y+9M3uS9+svvTJ6tP6uMPv7Zb1yepzwLKb/ngTu+p2ddjGMUccw/JLlid0HJc+fCnvffpeh+XDBsRv5zYehV4DDG3zfAiwPQ7vKwJE5h6LVxZTvbuaYQOGsWDSAgpHFwYd6wDuTkNTw2FPDbSuG+N6QWhfkgcr14OWb7vfO1s3OyMbs2j7jd2XkZFB0ZNF1DXUtS7Lyc7h7sl388W8L8ZlG525e/LdUbe9YNKCuG0jHoVeDtxgZkuB8cDurubPRWLVciCp5R9By4Ek4KCl3n7eNdaSjLpujOs1eVNS/k7a6pXZK/aS7KJUYynq3pm941auQWj5mQliByEZ2+7yoKiZLQG+BAwG3gP+DcgGcPf7LPLp/pzImTB1wP9x9y6PduqgqMQi/558tu7u+F1EvTN7M/bYsZ0WczLmXdvLysiKvSQPoUzbr9snqw+ZGZlJH6cE67AOirr7zC5ed+D6Q8wmEpW7s6p6VdQyh8hBrD9v+3Onf96wmP9XP15TA1kZgX15qQgQ4NfnikSzp2EPS95YQklFCa+991qn6x19xNE8cukjnRZvPOddRVKFCl16hG27t3Fv5b2Urivl/T3vA5HSnjhkIk//7ekDzh3Oyc7hJ5N/wlnDzgoqrkiPpEKXwLg7L297mZKKEh5981H2+34Axv7TWG4cfyOXfv5Semf1TomzXER6gpiuFE0EHRRNX3sb97L0jaWUVJSwfsd6IHJA8eKRFzN33FwmDJmg6RKRTiT6SlGRmLzz0TvcW3kvi9ctbr24Izcnl2vHXsvsgtkc1/+4gBOKpDYVuiSUu7O6ZjUlFSU88uYjracTjvnsGG4
cfyOXjbqMPll9Ak4pEg4qdEmIfY37WLZxGSUVJax7dx0AmZbJpZ+/lLnj5nLm0DM1rSISZyp0iavtH2/nvsr7WLxuMTs/3QnAUX2P4tqx13LdGdcxpP+QgBOKhJcKXQ6bu1PxTgUlFSU8vOnh1mmVU485lRvH38iMUTPom9034JQi4adCl0O2r3EfD296mJKKEtZuXwtEvhmv5WyVs4adpWkVkSRSoUu37fhkB/dV3sd9lfe1fh3ooL6DKDq9iOvOuC6uXwcqIrFToUvM1ryzhpKKEpZvXE5DUwMAo48ezY3jb+Tro7+uaRWRgKnQ5aDq99fzu02/o6SihIp3KoDItMq/jPgX5oybwzl552haRaSHUKFLVO998h6L1y3mvsr7ePeTyNfbD+wzkGtOv4ZvnfEt8gfmBxtQRDpQocsBKrdXUlJRwrKNy1pvN/b53M8zd/xcCkcXckSvIwJOKCKdUaELDfsbePTNR/lpxU9ZXbMaiHyf+PSTpzN3/FzOzT9X0yoiKUCFnsZqP62ldF0pv6j8Bds/jtwGdkDvAVx9+tVcf8b1HP+Z4wNOKCLdoUJPQ+vfXU/JmhKWbFjCvv37ABgxeARzx8/l8lMu58heRwacUEQOhQo9TTTsb+Dxtx6nZE0JL1W/BESmVaaeNJW54+cy6fhJmlYRSXEq9JDbVbeLX677Jb+o/AU1H9UA0L93f2aNmcX1Z1zPCYNOCDihiMSLCj2k/rLjL/ys4meUbShrnVY5+aiTmTt+LleeeqWmVURCSIUeIo1NjTzx1hOUrCnhxa0vti6/cPiFzB0/l/M+dx4ZlhFgQhFJJBV6CLxf9z6/evVXLFq7iG0fbQOgX69+XDXmKq4/43qGHzU84IQikgwq9BTS/mbJ1xZcy5YPtvDQhofY27gXgOGDhjN3/Fy+ceo36Ne7X8CJRSSZVOgpomxDGUVPFlHXUAfA1t1b+d7K77W+PuXEKcwdN5fJJ07WtIpImlKhp4jilcWtZd5Wv179WHvNWk4efHIAqUSkJ9GuXIqo3l0ddfkn9Z+ozEUEUKGnBHfv9DRD3UxCRFqo0FPAv73wb3xc/3GH5TnZOSyYtCCARCLSE6nQe7i7XrqLO168g0zLZN74eeQNyMMw8gbkUTq1lMLRhUFHFJEeQgdFe7Cfr/k581fOxzAe+NoDFJ5SyMIpC4OOJSI9VEx76GY2xcw2m1mVmc2P8vowM3vezNab2etmdkH8o6aX/1r/X8z5wxwAFn91MYWnaE9cRA6uy0I3s0xgEXA+MBKYaWYj2612K7Dc3ccAM4BfxDtoOln2xjKufvJqABZOXsg1Y68JOJGIpIJY9tDHAVXuvsXd64GlwPR26zjQv/nxAGB7/CKml/LN5Vz+2OU0eRN3nHsH8ybMCzqSiKSIWAr9OGBbm+c1zcva+j5wuZnVACuAOdHeyMyKzKzSzCpra2sPIW64Pfu3Z7nk4UtobGpk/sT5FJ9dHHQkEUkhsRR6tLseeLvnM4Ffu/sQ4ALgQbOO15+7e6m7F7h7QW5ubvfThtiqrauYvnQ69fvrmTNuDj+Y9APdcEJEuiWWQq8BhrZ5PoSOUyqzgOUA7r4a6AMMjkfAdFC5vZILf3shexr3cNVpV3HPlHtU5iLSbbEU+lpguJkdb2a9iBz0LG+3TjUwCcDMRhApdM2pxGDDexuY/NBkPq7/mMs+fxmlU0v15Voicki6bA53bwRuAJ4B3iRyNstGM7vdzKY1r/Yd4Bozew1YAnzT3dtPy0g7f33/r5z34Hl8sOcDpp08jQcvepDMjMygY4lIiorpwiJ3X0HkYGfbZbe1ebwJmBjfaOH29j/eZtJvJrHz0518+XNfZtnFy8jOzA46loikMP2/fQDe+egdJv1mEjUf1XDWsLN47LLH6JPVJ+hYIpLiVOhJVvtpLec9eB5bPtxCwbEF/P7rv+eIXkcEHUtEQkCFnkQf7vmQLz/4Zd7a9Rajjx7NM5c/Q//e/bv+gyIiMVC
hJ8nH+z7m/LLzee291zjpqJN49opnGdR3UNCxRCREVOhJUNdQx9QlU6l4p4L8gfmsvHIlxxx5TNCxRCRkVOgJtq9xH/+6/F/5763/zbH9juVPV/yJIf2HBB1LREJIhZ5AjU2NzHxkJk9XPU1uTi5/uuJPnDDohKBjiUhIqdATpMmb+Obj3+Sxtx5jYJ+B/PGKPzIid0TQsUQkxFToCeDuXPfUdZRtKOPIXkfydOHTnPbZ04KOJSIhp0KPM3fnpmduovTVUvpk9eGpmU8xfsj4oGOJSBpQocfZbc/fxj0V95Cdkc1jlz3GOfnnBB1JRNKECj2OfvjSD7lz1Z1kWibLLl7GlBOnBB1JRNKICj1Ofr7m59yy8hYM44GvPcBFIy4KOpKIpBkVehzcv/5+5vwhcte9xV9dTOEphQEnEpF0pEI/TEvfWMrV5VcDsHDyQq4Ze03AiUQkXanQD0P55nKueOwKHOfOc+9k3oR5QUcSkTSmQj9Ez/7tWS55+BIamxqZP3E+3zv7e0FHEpE0p0I/BKu2rmL60unU769nzrg5/GDSD3RTZxEJnAq9m9a+s5YLf3shexr3cNVpV3HPlHtU5iLSI6jQu+H1915n8kOT+bj+Y2aMmkHp1FIyTH+FItIzxHST6HRWtqGM4pXFVO+uxsxo8iamnTyN33ztN2RmZAYdT0SklQr9IMo2lFH0ZBF1DXVA5HtaMiyDi/75IrIzswNOJyJyIM0XHETxyuLWMm/R5E18/4XvBxNIROQgVOgHUb27ulvLRUSCpEI/iM5uFTdswLAkJxER6ZoK/SDOHnZ2h2U52TksmLQggDQiIgenQu/E3sa9PPf2cwDk5uRiGHkD8iidWkrhaH35loj0PDrLpRP3r7+fHZ/s4LTPnsarRa/q4iER6fG0hx5F/f567nr5LgBuPftWlbmIpAQVehQPvf4Q1burGTF4hG5UISIpI6ZCN7MpZrbZzKrMbH4n61xqZpvMbKOZ/Ta+MZNnf9N+/v2lfweg+OxiXdovIimjyzl0M8sEFgFfBmqAtWZW7u6b2qwzHLgFmOjuH5rZ0YkKnGjLNy6n6oMqTvjMCVw26rKg44iIxCyW3c9xQJW7b3H3emApML3dOtcAi9z9QwB33xnfmMnR5E0sWBU5JfGWs24hK0PHjEUkdcRS6McB29o8r2le1tZJwElm9rKZvWJmUW93b2ZFZlZpZpW1tbWHljiBnnjrCTbWbmRo/6FcceoVQccREemWWAo92ike3u55FjAc+BIwE/iVmQ3s8IfcS929wN0LcnNzu5s1odydO1fdCcDNE2+mV2avgBOJiHRPLIVeAwxt83wIsD3KOk+4e4O7/x3YTKTgU8bTVU/z6ruvcswRxzBrzKyg44iIdFsshb4WGG5mx5tZL2AGUN5unceBcwHMbDCRKZgt8QyaSO7OHS/eAcB3z/wufbP7BpxIRKT7uix0d28EbgCeAd4Elrv7RjO73cymNa/2DPC+mW0Cngf+r7u/n6jQ8fbC2y+wumY1g/oOYnbB7KDjiIgckphO43D3FcCKdstua/PYgZuaf6Wclrnzb0/4Nkf2OjLgNCIihybtr5pZvW01z/39Ofr37s8N424IOo6IyCFL+0JvOe98zrg5DOzT4cQcEZGUkdaFvv7d9fz+f35PTnYO8ybMCzqOiMhhSetCb9k7nz12NoNzBgecRkTk8KRtoW/cuZFH3nyE3pm9+c6Z3wk6jojIYUvbQm/5RsVZY2ZxbL9jA04jInL40rLQqz6oYskbS8jKyOLmiTcHHUdEJC7SstB/+NIPafImrjzlSvIG5gUdR0QkLtKu0Kt3V/PAaw+QYRnMPyvqvTpERFJS2hX6j1/+MY1NjcwYNYPhR6XU94eJiBxUWhX6jk928MtXfwlEbmAhIhImaVXod//5bvbt38dF/3wRo44eFXQcEZG4SptC31W3i3sr7wUiN38WEQmbtCn0n77yUz5t+JTzTzyfsceODTqOiEjcpUW
h/2PvPyhZUwLArV+8NeA0IiKJkRaFvmjNIj7a9xHn5p/LmUPPDDqOiEhChL7QP6n/hIWvLAS0dy4i4Rb6Ql9cuZj397zPF4Z8gXPzzw06johIwoS60Pc27uU/Vv8HEDmzxcwCTiQikjihLvT719/Pjk92cNpnT+OC4RcEHUdEJKFCW+j1++u56+W7ALj17Fu1dy4ioRfaQn/o9Yeo3l3NiMEjuGjERUHHERFJuFAWemNTY+sNLIrPLibDQjlMEZEDhLLplm9cTtUHVZzwmRO4bNRlQccREUmK0BV6kze13vz5lrNuISsjK+BEIiLJEbpCf/ytx9lUu4mh/YdyxalXBB1HRCRpQlXo7s6dL94JwM0Tb6ZXZq+AE4mIJE+oCv3pqqdZv2M9xxxxDLPGzAo6johIUoWm0N2dO168A4Dvnvld+mb3DTiRiEhyhabQX3j7BVbXrGZQ30HMLpgddBwRkaQLTaHfuSoyd/7tCd/myF5HBpxGRCT5Yip0M5tiZpvNrMrM5h9kvYvNzM2sIH4Ru/bnbX/mub8/R//e/blh3A3J3LSISI/RZaGbWSawCDgfGAnMNLORUdbrB8wFKuIdsist553PGTeHgX0GJnvzIiI9Qix76OOAKnff4u71wFJgepT17gB+BOyNY74uvfruq6z4nxXkZOcwb8K8ZG5aRKRHiaXQjwO2tXle07yslZmNAYa6+1MHeyMzKzKzSjOrrK2t7XbYaFr2zmePnc3gnMFxeU8RkVQUS6FH+95Zb33RLANYCHynqzdy91J3L3D3gtzc3NhTdmLjzo08+uaj9M7szXfO7HLzIiKhFkuh1wBD2zwfAmxv87wfMAp4wczeBiYA5ck4MNryjYqzxszi2H7HJnpzIiI9WiyFvhYYbmbHm1kvYAZQ3vKiu+9298Hunu/u+cArwDR3r0xI4mZVH1Sx5I0lZGVkcfPEmxO5KRGRlNBlobt7I3AD8AzwJrDc3Tea2e1mNi3RAdsr21BG/j35DP/ZcJq8iYlDJpI3MC/ZMUREepyYvlvW3VcAK9otu62Tdb90+LGiK9tQRtGTRdQ11LUuq9heQdmGMgpHFyZqsyIiKSGlrhQtXll8QJkD7G3cS/HK4oASiYj0HClV6NW7q7u1XEQknaRUoQ8bMKxby0VE0klKFfqCSQvIyc45YFlOdg4LJi0IKJGISM+RUoVeOLqQ0qml5GRFSn1wzmBKp5bqgKiICClW6BAp9a+c+BUASr+qMhcRaZFyhQ7Q2NQIQFZGTGddioikBRW6iEhIqNBFREIipQs9MyMz4CQiIj1HShb6/qb9gPbQRUTaSslC15SLiEhHKnQRkZBQoYuIhIQKXUQkJFToIiIhoUIXEQkJFbqISEikdKFnmi4sEhFpkZKFvt91YZGISHspWeiachER6UiFLiISEip0EZGQUKGLiISECl1EJCRU6CIiIZFyhd7kTTR5EwAZlnLxRUQSJuUase3NLcws4DQiIj1HyhW6rhIVEYku5QpdV4mKiEQXU6Gb2RQz22xmVWY2P8rrN5nZJjN73cxWmlle/KNG6ICoiEh0XRa6mWUCi4DzgZHATDMb2W619UCBu58C/A74UbyDtlChi4hEF8se+jigyt23uHs9sBSY3nYFd3/e3euan74CDIlvzP+lQhcRiS6WQj8O2NbmeU3zss7MAv4Q7QUzKzKzSjOrrK2tjT1lGyp0EZHoYin0aOcGetQVzS4HCoAfR3vd3UvdvcDdC3Jzc2NP2YYKXUQkulhasQYY2ub5EGB7+5XM7DygGDjH3ffFJ15HKnQRkehi2UNfCww3s+PNrBcwAyhvu4KZjQEWA9PcfWf8Y/4vFbqISHRdFrq7NwI3AM8AbwLL3X2jmd1uZtOaV/sxcCTwsJn9xczKO3m7w6ZCFxGJLqZWdPcVwIp2y25r8/i8OOfqVOuVohm6UlREpK3Uu1K0SVeKiohEk3KFrikXEZHoVOgiIiGhQhcRCQkVuohISKjQRUR
CQoUuIhISKnQRkZBI2ULXLehERA6UsoWuPXQRkQOlXKHrnqIiItGlXKFrD11EJDoVuohISKjQRURCQoUuIhISKnQRkZBQoYuIhIQKXUQkJFK20HWlqIjIgVK20LWHLiJyoJQrdN1TVEQkupQrdO2hi4hEp0IXEQkJFbqISEio0EVEQkKFLiISEip0EZGQUKGLiIRE6hW6N18pmqErRUVE2kq5QteFRSIi0aVcoWvKRUQkupgK3cymmNlmM6sys/lRXu9tZsuaX68ws/x4BwUo21BG+eZyAOb8YQ5lG8oSsRkRkZTUZaGbWSawCDgfGAnMNLOR7VabBXzo7icCC4G74h20bEMZRU8WsadxDwC76nZR9GSRSl1EpFkse+jjgCp33+Lu9cBSYHq7daYDDzQ//h0wycwsfjGheGUxdQ11Byyra6ijeGVxPDcjIpKyYin044BtbZ7XNC+Luo67NwK7gaPav5GZFZlZpZlV1tbWdito9e7qbi0XEUk3sRR6tD1tP4R1cPdSdy9w94Lc3NxY8rUaNmBYt5aLiKSbWAq9Bhja5vkQYHtn65hZFjAA+CAeAVssmLSAnOycA5blZOewYNKCeG5GRCRlxVLoa4HhZna8mfUCZgDl7dYpB77R/Phi4Dl377CHfjgKRxdSOrWUvAF5GEbegDxKp5ZSOLownpsREUlZXZ7M7e6NZnYD8AyQCdzv7hvN7Hag0t3Lgf8EHjSzKiJ75jMSEbZwdKEKXESkEzFdnePuK4AV7Zbd1ubxXuCS+EYTEZHuSLkrRUVEJDoVuohISKjQRURCQoUuIhISFuezC2PfsFktsPUQ//hgYFcc46QCjTk9aMzp4XDGnOfuUa/MDKzQD4eZVbp7QdA5kkljTg8ac3pI1Jg15SIiEhIqdBGRkEjVQi8NOkAANOb0oDGnh4SMOSXn0EVEpKNU3UMXEZF2VOgiIiHRowu9p9ycOpliGPNNZrbJzF43s5VmlhdEznjqasxt1rvYzNzMUv4Ut1jGbGaXNn/WG83st8nOGG8x/GwPM7PnzWx988/3BUHkjBczu9/MdprZG528bmZW0vz38bqZnX7YG3X3HvmLyFf1/g34HNALeA0Y2W6dbwH3NT+eASwLOncSxnwukNP8+Lp0GHPzev2AF4FXgIKgcyfhcx4OrAc+0/z86KBzJ2HMpcB1zY9HAm8Hnfswx/xF4HTgjU5evwD4A5E7vk0AKg53mz15D71H3Jw6ybocs7s/7+4td8t+hcgdpFJZLJ8zwB3Aj4C9yQyXILGM+Rpgkbt/CODuO5OcMd5iGbMD/ZsfD6DjndFSiru/yMHv3DYd+I1HvAIMNLN/Opxt9uRCj9vNqVNILGNuaxaR/8Knsi7HbGZjgKHu/lQygyVQLJ/zScBJZvaymb1iZlOSli4xYhnz94HLzayGyP0X5iQnWmC6+++9SzHd4CIgcbs5dQqJeTxmdjlQAJyT0ESJd9Axm1kGsBD4ZrICJUEsn3MWkWmXLxH5v7BVZjbK3f+R4GyJEsuYZwK/dve7zewLRO6CNsrdmxIfLxBx76+evIfeI25OnWSxjBkzOw8oBqa5+74kZUuUrsbcDxgFvGBmbxOZayxP8QOjsf5sP+HuDe7+d2AzkYJPVbGMeRawHMDdVwN9iHyJVVjF9O+9O3pyofeIm1MnWZdjbp5+WEykzFN9XhW6GLO773b3we6e7+75RI4bTHP3ymDixkUsP9uPEzkAjpkNJjIFsyWpKeMrljFXA5MAzGwEkUKvTWrK5CoHrmw+22UCsNvd3z2sdwz6SHAXR4kvAP5K5Oh4cfOy24n8g4bIB/4wUAWsAT4XdOYkjPlPwHvAX5p/lQedOdFjbrfuC6T4WS4xfs4G/ATYBGwAZgSdOQljHgm8TOQMmL8AXwk682GOdwnwLtBAZG98FjAbmN3mM17U/PexIR4/17r0X0QkJHrylIuIiHSDCl1EJCRU6CIiIaFCFxEJCRW6iEhIqNBFREJChS4
iEhL/H6eE+pxJbwQQAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "a = copy.deepcopy(model.tree.source) \n", + "plot_roc(a, X_guessed, y)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### (c) Partial AUC convex hull objective" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss function: pauc\n", + "lambda: 0.01\n", + "COUNT_UNIQLEAVES: 829\n", + "COUNT_LEAFLOOKUPS: 229855\n", + "total time: 60.00115752220154\n", + "leaves: [(4,), (-4, 6), (-7, -6, -4), (-6, -4, 7)]\n", + "num_captured: [554, 114, 56, 138]\n", + "prediction: [1, 1, 1, 1]\n", + "Objective: 0.9218638392234353\n", + "pauc : 0.11813616077656475\n", + "COUNT of the best tree: 36\n", + "time when the best tree is achieved: 0.02991962432861328\n", + "TOTAL COUNT: 115120\n", + "{'feature': 3, 'name': 'f1<=120.5', 'reference': 1, 'relation': '==', 'true': {'complexity': 0.01, 'loss': 0.5185614849187898, 'name': 'class', 'prediction': 1}, 'false': {'feature': 5, 'name': 'f2<=82.5', 'reference': 1, 'relation': '==', 'true': {'complexity': 0.01, 'loss': 0.002320185614849188, 'name': 'class', 'prediction': 1}, 'false': {'feature': 6, 'name': 'f2<=167.5', 'reference': 1, 'relation': '==', 'true': {'complexity': 0.01, 'loss': 0.058004640371229696, 'name': 'class', 'prediction': 1}, 'false': {'complexity': 0.01, 'loss': 0.06496519721577726, 'name': 'class', 'prediction': 1}}}}\n" + ] + } + ], + "source": [ + "config={\n", + " \"regularization\": 0.01,\n", + " \"objective\": \"pauc\",\n", + " \"theta\": 0.2,\n", + " \"time_limit\": 60\n", + "}\n", + "model = GOSDT(config)\n", + "model.fit(X_guessed, y)\n", + "print(model.tree.source)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "area under roc curve: 
0.8123866537547318\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAdUUlEQVR4nO3de3BV53nv8e+jGyBzRxgMQhIgnIDxXbGxCRgMNnKmsSedNMGHNKdnmGrcHLvjxD2JU2VyeuwwzaVuLj2eNKTHk7Shdp22k+AECzA32wRiZGyDwSWWQAiZm7hYgIUQkp7zxxbL2mKDtszWXtpLv8+MZvZe79Jez0LSz6/Xfp+9zN0REZHMlxV2ASIikhoKdBGRiFCgi4hEhAJdRCQiFOgiIhGRE9aBCwoKvKSkJKzDi4hkpNdff/2Yu49NNBZaoJeUlFBdXR3W4UVEMpKZ7b/UmC65iIhEhAJdRCQiFOgiIhGhQBcRiQgFuohIRPQY6Gb2jJkdNbO3LzFuZvYjM6sxsx1mdkvqyxQRSY0VO1dQ8oMSsv5PFiU/KGHFzhWROXYyM/SfAeWXGb8PmNb5VQH8+MrLEhFJvRU7V1DxQgX7m/bjOPub9lPxQkVaQj0dx7ZkPj7XzEqA37j7zARjPwE2uvuznc/3APPc/dDlXrOsrMy1Dl1E0qnw7wt57/R7F20flD2I2ybe1qfHfu291zjXfu6i7cUjiql7tC7p1zGz1929LNFYKhqLJgIHujxv6Nx2UaCbWQWxWTxFRUUpOLSIyKW1tLXwav2rVNVUUVVTlTDMAc61n+OV+lfSXF1MfVN9yl4rFYFuCbYlnPa7+3JgOcRm6Ck4tohInJoTNUGAb6jbQPP55mDMMDxBPI27ahzP/8nzfVrX5375OY58cOSi7UUjUje5TUWgNwCTujwvBA6m4HVFRHp0pvUMG/ZtiIV4bRV7T+6NG79p/E2UTy2nvLScuqY6vvTbL8WFfH5uPk8teoq5xXP7tM6nFj1FxQsVFx172YJlKTtGKgJ9JfCwmT0H3A409XT9XETko3J33j76dhDgr+x/hfMd54Px0UNGc+/UeymfWs69U+/lmmHXBGN3cRc5WTlUrqukvqmeohFFLFuwjCXXL+nzui8coy+P3eObomb2LDAPKACOAP8byAVw9380MwP+L7GVMM3A/3D3Ht/t1JuiIpKsk2dP8tLel4IQP3j6w4sAhnF74e3BLLxsQhnZWdkhVtu3ruhNUXd/sIdxB/7nR6xNROQiHd7B6wdfDwJ8a8NWOrwjGB8/dDzlpeWUTy1n4ZSFjMkfE2K1/UdoH58rItLVkTNHWFO7hqraKtbUruFY87FgLCcrh7nFc4NZ+A3jbiB2cUC6UqCLSCjOt59nS8MWVtespqq2iu2HtseNl4ws4b7S+ygvLWd+yXyGDRoWUqWZQ4EuImlT31QfLClct28dp86dCsYG5wxmfsn82KWU0nKmjZ6mWXgvKdBFpM+0tLXw8v6XgxB/59g7cePTC6YHAT6naA5DcoeEVGk0KNBFJGXcnXdPvBsE+Ma6jZxtOxuMD8sbxsIpCykvLWfR1EUUjywOsdroUaCLyBU5fe40G+o2BCG+7/19ceM3j785mIXfUXgHudm5IVUafQp0EekVd2fHkR1U1VSxunY1r9a/GtfYM2bImFhjT2mssWf80PEhVjuwKNBFpEcnzp5gbe1aqmqrWF2zmkNnPmwGz7Is7px0Z7Ck8JZrbol0Y09/pkAXkYu0d7RTfbA6aOx57b3X4hp7JgybEAT4wikLGTVkVIjVygUKdBEB4PCZw8Ga8DW1azhx9kQwlpuVy7ySeUGIz7x6ppYU9kMKdJEBqrW9lS0HtgSz8DcPvxk3Pnnk5A8beybPZ2je0JAqlWQp0EUGkLr364JZ+Lq96zjde
joYG5IzhPmT5wez8NLRpZqFZxgFukiEnT1/lk37NwVLCvcc3xM3PmPsjCDA5xTPYXDO4JAqlVRQoItEiLuz5/ieIMA37d9ES1tLMD580HDumXJP0NgzacSky7yaZBoFukiGO3XuFOv3rQ9CfH/T/rjxW6+5NWjsuX3i7WrsiTAFukiGcXfeOvJWEOCbD2ymraMtGC/IL2DR1EVBY8/VV10dYrWSTgp0kQxwvPk4a/euDbozD585HIxlWRazJ80OZuG3XHMLWZYVYrUSFgW6SD/U3tHOtoPbePHdF6mqrWLbe9vi7lY/cdjEIMAXTF6gxh4BFOgi/cbB0weDJYVra9dysuVkMJaXncecojlBiF839jotKZSLKNBFQtLa3srm+s1BY8+OIzvixktHlwZLCueVzOOqvKtCqlQyhQJdJI32ndwXBPj6fes503omGMvPzefuyXdTPrWcRaWLKB1dGmKlkokU6CJ9qPl8M5vqNgUh/ofjf4gbn3n1zGAW/smiTzIoZ1BIlUoUKNBFUsjdeefYO8G18E11mzjXfi4YHzFoBPdMvSeYhRcOLwyxWokaBbrIFWpqaWLdvnXBuvADpw7EjZdNKAtm4bcX3k5Olv7spG/oN0uklzq8gzcPvxkE+O8O/I52bw/Gx+aPZVHpIsqnlnPP1HvU2CNpo0AXSULjB41xjT1HPzgajGVbdtySwpvG36TGHgmFAl0kgbaONl5777VgFl59sDqusWfS8ElxjT0jBo8IsVqRGAW6SKf3Tr3H6trVVNVUsXbvWt5veT8Yy8vO467iu4IQn14wXY090u8o0GXAOtd2jlfrXw0uo+w8ujNufNroaUGA31V8lxp7pN9ToMuAUnuiNq6xp/l8czB2Ve5VLJiyIFhSOGXUlBArFek9BbpE2getH7CxbmMQ4jUnauLGbxh3Q7Ck8M5Jd6qxRzJaUoFuZuXAD4Fs4J/c/dvdxouAnwMjO/d53N1XpbhWkR65O7sbdwcB/vL+l2ltbw3GRw4eyb1T76V8auyzwicOnxhitSKp1WOgm1k28DRwD9AAbDOzle6+u8tu3wCed/cfm9kMYBVQ0gf1ilzk/Zb3Wbd3XRDiDacagjHDuG3ibcEs/BMTP6HGHomsZH6zbwNq3H0vgJk9BzwAdA10B4Z3Ph4BHExlkSJddXgH2w9tD5YUbm3YGtfYM+6qcXGNPQX5BSFWK5I+yQT6RKBrL3MDcHu3ff4GWGNmjwBXAQsTvZCZVQAVAEVFRb2tVQawox8cZU3tGqpqqlhTu4bG5sZgLCcrh7lFc4NZ+I3jb1RjjwxIyQR6osW23u35g8DP3P0pM7sD+Bczm+nuHXHf5L4cWA5QVlbW/TVEAm0dbWxt2BrMwl8/9HrceNGIIu4rvY/y0nLunnw3wwcNv8QriQwcyQR6AzCpy/NCLr6kshQoB3D3LWY2GCgAjiKSpANNB4LGnpf2vkTTuaZgbFD2IOaVzAvWhX9szMfU2CPSTTKBvg2YZmaTgfeAxcB/67ZPPbAA+JmZTQcGA42IXEZLW0vQ2FNVU8Wuxl1x4x8b87EgwOcWzyU/Nz+kSkUyQ4+B7u5tZvYwsJrYksRn3H2XmT0BVLv7SuAx4Kdm9mVil2P+zN11SUUuUnOiJgjwDXUb4hp7huYNZcHkBZSXlrNo6iImj5ocYqUimSep9Vuda8pXddv2zS6PdwOzU1uaRMGZ1jNs2LchWFK49+TeuPEbx90YzMLvnHQnedl5IVUqkvm0IFdSyt15++jbQYC/sv8VznecD8ZHDR4Va+wpjTX2TBg2IcRqRaJFgS5X7OTZk7y096UgxA+e/vA9c8OYVTgrWFJYNqGM7KzsEKsViS4FuvRah3fw+sHXgwDf2rCVji4rVMcPHR+7jDK1nIVTFjImf0yI1YoMHAp0ScqRM0dijT21scaeY83HgrGcrBzmFn/Y2HPDuBu0pFAkBAp0Seh8+3m2NGwJ7l6//dD2uPHiEcVBY8/8yfPV2
CPSDyjQJVDfVB8sKVy3bx2nzp0KxgbnDI419nTOwq8dc61m4SL9jAJ9AGtpa+Hl/S8HIf7OsXfixj9e8PEgwOcWz2VI7pCQKhWRZCjQBxB3590T7wYBvrFuI2fbzgbjw/KGsXDKwqCxp3hkcYjVikhvKdAj7vS502yo2xCE+L7398WN3zz+5qCx547CO8jNzg2pUhG5Ugr0iHF3dhzZEdz4+NX6V+Mae8YMGRPX2DN+6PgQqxWRVFKgR8CJsydYW7uWqtoqVtes5tCZQ8FYlmVxR+EdwSz81mtuVWOPSEQp0DNQe0c71Qerg8ae1957La6x55qh1wQBvnDKQkYPGR1itSKSLgr0DHH4zOFgTfia2jWcOHsiGMvNymVeyTwWTV1EeWk51199vZYUigxACvR+qrW9lS0HtgSz8DcPvxk3Pnnk5LjGnqF5Q0OqVET6CwV6P1L3fl0wC1+3dx2nW08HY0NyhjB/8vxgXXjp6FLNwkUkjgI9RGfPn2XT/k3BksI9x/fEjc8YOyMI8DnFcxicMzikSkUkEyjQ08jd2XN8TxDgm/ZvoqWtJRgfPmh4rLFnajmLShdRNKIoxGpFJNMo0PvYqXOnWL9vfRDi+5v2x43fcs0twSx8VuEsNfaIyEemQE8xd+etI28FAb75wGbaOtqC8YL8AhZNXcSiqYu4d+q9jBs6LsRqRSRKFOgpcLz5OGv3rg26Mw+fORyMZVkWsyfNDtaF33LNLWRZVojVikhUKdB7sGLnCirXVVLfVE/RiCKWLVjG4usWs+3gNl5890WqaqvY9t42HA++Z+KwiUGAL5i8gFFDRoV4BiIyUJi797xXHygrK/Pq6upQjp2sFTtXUPFCBc3nm4Nt2ZbNoOxBNLd9uC0vO485RXOCEL9u7HVaUigifcLMXnf3skRjmqFfRuW6yrgwB2j3dprbmpkyakrQ2DOvZJ4ae0QkdAr0y6hvqk+43TBq/7I2zdWIiFye3p27jMLhhQm3a324iPRHCvTLmFM056Jt+bn5LFuwLIRqREQuT4F+CS1tLayvWw/A2PyxGEbxiGKWf3o5S65fEnJ1IiIX0zX0S3jmjWc4fOYwN42/ie0V27VqRUT6Pc3QE2htb+U7m78DQOWcSoW5iGQEBXoCv9jxC+qb6pleMJ0/nv7HYZcjIpKUpALdzMrNbI+Z1ZjZ45fY53NmttvMdpnZv6a2zPRp72jnb1/9WwD+es5fq01fRDJGj9fQzSwbeBq4B2gAtpnZSnff3WWfacDXgdnuftLMru6rgvva87uep+ZEDVNGTWHxzMVhlyMikrRkpp+3ATXuvtfdW4HngAe67fPnwNPufhLA3Y+mtsz06PAOlr0SW5L49U9+nZwsvWcsIpkjmUCfCBzo8ryhc1tX1wLXmtlmM9tqZuWJXsjMKsys2syqGxsbP1rFfejX//VrdjXuonB4IV+88YthlyMi0ivJBHqiJR7dP9ErB5gGzAMeBP7JzEZe9E3uy929zN3Lxo4d29ta+5S7861XvgXA12Z/jbzsvJArEhHpnWQCvQGY1OV5IXAwwT6/dvfz7r4P2EMs4DNGVU0V2w9tZ9xV41h689KwyxER6bVkAn0bMM3MJptZHrAYWNltn18B8wHMrIDYJZi9qSy0L7k7T778JACP3fEYQ3KHhFyRiEjv9Rjo7t4GPAysBt4Bnnf3XWb2hJnd37nbauC4me0GNgD/y92P91XRqbaxbiNbGrYweshoHip7KOxyREQ+kqSWcbj7KmBVt23f7PLYga90fmWcC9fOH739UYYNGhZyNSIiH82A75rZcmAL6/etZ/ig4Txy+yNhlyMi8pEN+EC/sO784U88zMjBFy3MERHJGAM60N849Aa/ffe35Ofm8+isR8MuR0TkigzoQL8wO3/o1ocYe1X/WhcvItJbAzbQdx3dxX+88x8Myh7EY3c+FnY5IiJXbMAG+oVPVFx681ImDJsQcjUiIlduQAZ6zYkann37WXKycvjq7
K+GXY6ISEoMyED/9qvfpsM7+NMb/pTikcVhlyMikhIDLtDrm+r5+Vs/J8uyePyTCe/VISKSkQZcoH9v8/do62jj89d9nmvHXBt2OSIiKTOgAv3wmcP8dPtPgdjt5UREomRABfpTv3uKc+3n+MzHP8PMq2eGXY6ISEoNmEA/1nyMH1f/GIDKOZUhVyMiknoDJtB/uPWHfHD+A+4rvY9bJ9wadjkiIik3IAL9/Zb3+dFrPwLgG3O/EXI1IiJ9Y0AE+tOvPc2pc6eYVzKPOyfdGXY5IiJ9IvKBfqb1DN/f+n0AvjFHs3MRia7IB/pPqn/C8bPHmVU4i7sn3x12OSIifSbSgd7S1sLfbfk7IDY7N7OQKxIR6TuRDvRn3niGw2cOc9P4m/jUtE+FXY6ISJ+KbKC3trfync3fATQ7F5GBIbKB/osdv6C+qZ7pBdP5zPTPhF2OiEifi2Sgt3W0BTewqJxTSZZF8jRFROJEMume3/U8NSdqmDpqKp+f+fmwyxERSYvIBXqHdwQ3f378k4+Tk5UTckUiIukRuUD/1X/9it2NuykcXsgXb/xi2OWIiKRNpALd3fnWy98C4Guzv0Zedl7IFYmIpE+kAr2qpoo3Dr/BuKvGsfTmpWGXIyKSVpEJdHfnyZefBOCv7vwrhuQOCbkiEZH0ikygb6zbyJaGLYweMpqHyh4KuxwRkbSLTKB/65XYtfMvz/oyQ/OGhlyNiEj6JRXoZlZuZnvMrMbMHr/Mfp81MzezstSV2LPfHfgd6/etZ/ig4Tx828PpPLSISL/RY6CbWTbwNHAfMAN40MxmJNhvGPCXwO9TXWRPLqw7f+S2Rxg5eGS6Dy8i0i8kM0O/Dahx973u3go8BzyQYL8nge8CLSmsr0fbD21n1buryM/N59FZj6bz0CIi/UoygT4RONDleUPntoCZ3QxMcvffXO6FzKzCzKrNrLqxsbHXxSZyYXb+0K0PUZBfkJLXFBHJRMkEeqLPnfVg0CwL+D7wWE8v5O7L3b3M3cvGjh2bfJWXsOvoLv7znf9kUPYgHruzx8OLiERaMoHeAEzq8rwQONjl+TBgJrDRzOqAWcDKdLwxeuETFZfevJQJwyb09eFERPq1ZAJ9GzDNzCabWR6wGFh5YdDdm9y9wN1L3L0E2Arc7+7VfVJxp5oTNTz79rPkZOXw1dlf7ctDiYhkhB4D3d3bgIeB1cA7wPPuvsvMnjCz+/u6wO5W7FxByQ9KmPYP0+jwDmYXzqZ4ZHG6yxAR6XeS+mxZd18FrOq27ZuX2HfelZeV2IqdK6h4oYLm883Btt8f/D0rdq5gyfVL+uqwIiIZIaM6RSvXVcaFOUBLWwuV6ypDqkhEpP/IqECvb6rv1XYRkYEkowK9aERRr7aLiAwkGRXoyxYsIz83P25bfm4+yxYsC6kiEZH+I6MCfcn1S1j+6eXk58RCvSC/gOWfXq43REVEyLBAh1io31t6LwDL/0hhLiJyQcYFOkBbRxsAOVlJrboUERkQFOgiIhGhQBcRiYiMDvTsrOyQKxER6T8yMtDbO9oBzdBFRLrKyEDXJRcRkYsp0EVEIkKBLiISEQp0EZGIUKCLiESEAl1EJCIU6CIiEZHRgZ5taiwSEbkgIwO93dVYJCLSXUYGui65iIhcTIEuIhIRCnQRkYhQoIuIRIQCXUQkIhToIiIRkXGB3uEddHgHAFmWceWLiPSZjEvErje3MLOQqxER6T8yLtDVJSoikljGBbq6REVEEksq0M2s3Mz2mFmNmT2eYPwrZrbbzHaY2TozK059qTF6Q1REJLEeA93MsoGngfuAGcCDZjaj225vAGXufgPw78B3U13oBQp0EZHEkpmh3wbUuPted28FngMe6LqDu29w9+bOp1uBwtSW+SEFuohIYskE+kTgQJfnDZ3bLmUp8GKiATOrMLNqM6tubGxMvsouFOgiIoklE+iJ1gZ6wh3NvgCUAd9LNO7uy929zN3Lxo4dm3yVXSjQRUQSSyYVG
4BJXZ4XAge772RmC4FK4C53P5ea8i6mQBcRSSyZGfo2YJqZTTazPGAxsLLrDmZ2M/AT4H53P5r6Mj+kQBcRSazHQHf3NuBhYDXwDvC8u+8ysyfM7P7O3b4HDAV+aWZvmtnKS7zcFVOgi4gkllQquvsqYFW3bd/s8nhhiuu6pKBTNEudoiIiXWVep2iHOkVFRBLJuEDXJRcRkcQU6CIiEaFAFxGJCAW6iEhEKNBFRCJCgS4iEhEKdBGRiMjYQNct6ERE4mVsoGuGLiISL+MCXfcUFRFJLOMCXTN0EZHEFOgiIhGhQBcRiQgFuohIRCjQRUQiQoEuIhIRCnQRkYjI2EBXp6iISLyMDXTN0EVE4mVcoOueoiIiiWVcoGuGLiKSmAJdRCQiFOgiIhGhQBcRiQgFuohIRCjQRUQiQoEuIhIRmRfo3tkpmqVOURGRrjIu0NVYJCKSWMYFui65iIgkllSgm1m5me0xsxozezzB+CAz+7fO8d+bWUmqCwVYsXMFK/esBOCRFx9hxc4VfXEYEZGM1GOgm1k28DRwHzADeNDMZnTbbSlw0t1Lge8D30l1oSt2rqDihQrOtp0F4FjzMSpeqFCoi4h0SmaGfhtQ4+573b0VeA54oNs+DwA/73z878ACM7PUlQmV6yppPt8ct635fDOV6ypTeRgRkYyVTKBPBA50ed7QuS3hPu7eBjQBY7q/kJlVmFm1mVU3Njb2qtD6pvpebRcRGWiSCfREM23/CPvg7svdvczdy8aOHZtMfYGiEUW92i4iMtAkE+gNwKQuzwuBg5fax8xygBHAiVQUeMGyBcvIz82P25afm8+yBctSeRgRkYyVTKBvA6aZ2WQzywMWAyu77bMS+O+djz8LrHf3i2boV2LJ9UtY/unlFI8oxjCKRxSz/NPLWXL9klQeRkQkY/W4mNvd28zsYWA1kA084+67zOwJoNrdVwL/D/gXM6shNjNf3BfFLrl+iQJcROQSkurOcfdVwKpu277Z5XEL8CepLU1ERHoj4zpFRUQkMQW6iEhEKNBFRCJCgS4iEhGW4tWFyR/YrBHY/xG/vQA4lsJyMoHOeWDQOQ8MV3LOxe6esDMztEC/EmZW7e5lYdeRTjrngUHnPDD01TnrkouISEQo0EVEIiJTA3152AWEQOc8MOicB4Y+OeeMvIYuIiIXy9QZuoiIdKNAFxGJiH4d6P3l5tTplMQ5f8XMdpvZDjNbZ2bFYdSZSj2dc5f9PmtmbmYZv8QtmXM2s891/qx3mdm/prvGVEvid7vIzDaY2Rudv9+fCqPOVDGzZ8zsqJm9fYlxM7Mfdf577DCzW674oO7eL7+IfVRvLTAFyAPeAmZ02+dLwD92Pl4M/FvYdafhnOcD+Z2P/2IgnHPnfsOAl4GtQFnYdafh5zwNeAMY1fn86rDrTsM5Lwf+ovPxDKAu7Lqv8JznArcAb19i/FPAi8Tu+DYL+P2VHrM/z9D7xc2p06zHc3b3De5+4W7ZW4ndQSqTJfNzBngS+C7Qks7i+kgy5/znwNPufhLA3Y+mucZUS+acHRje+XgEF98ZLaO4+8tc/s5tDwD/7DFbgZFmds2VHLM/B3rKbk6dQZI5566WEvsvfCbr8ZzN7GZgkrv/Jp2F9aFkfs7XAtea2WYz22pm5Wmrrm8kc85/A3zBzBqI3X/hkfSUFpre/r33KKkbXIQkZTenziBJn4+ZfQEoA+7q04r63mXP2cyygO8Df5augtIgmZ9zDrHLLvOI/V/YK2Y2093f7+Pa+koy5/wg8DN3f8rM7iB2F7SZ7t7R9+WFIuX51Z9n6P3i5tRplsw5Y2YLgUrgfnc/l6ba+kpP5zwMmAlsNLM6YtcaV2b4G6PJ/m7/2t3Pu/s+YA+xgM9UyZzzUuB5AHffAgwm9iFWUZXU33tv9OdA7xc3p06zHs+58/LDT4iFeaZfV4Ueztndm9y9wN1L3L2E2PsG97t7dTjlpkQyv9u/IvYGOGZWQOwSzN60Vplay
ZxzPbAAwMymEwv0xrRWmV4rgS92rnaZBTS5+6EresWw3wnu4V3iTwF/IPbueGXntieI/UFD7Af+S6AGeA2YEnbNaTjnl4AjwJudXyvDrrmvz7nbvhvJ8FUuSf6cDfh7YDewE1gcds1pOOcZwGZiK2DeBO4Nu+YrPN9ngUPAeWKz8aXAQ8BDXX7GT3f+e+xMxe+1Wv9FRCKiP19yERGRXlCgi4hEhAJdRCQiFOgiIhGhQBcRiQgFuohIRCjQRUQi4v8DkfPo8ibhq+UAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "a = copy.deepcopy(model.tree.source) \n", + "plot_roc(a, X_guessed, y)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this example, we can see different ROC curves when optimizing different objectives. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/gosdt/model/gosdt.py b/gosdt/model/gosdt.py index 9869bee..3e58db7 100644 --- a/gosdt/model/gosdt.py +++ b/gosdt/model/gosdt.py @@ -117,7 +117,6 @@ def fit(self, X, y): self.configuration["theta"] = None elif self.configuration["objective"] == "f1": self.configuration["theta"] = None - self.configuration["w"] = None elif self.configuration["objective"] == "auc": self.configuration["theta"] = None self.configuration["w"] = None @@ -274,7 +273,7 @@ def __python_train__(self, X, y): trains a model using the GOSDT pure Python implementation modified from OSDT """ - encoder = Encoder(X.values[:,:], header=X.columns[:], mode="complete", target=y[y.columns[0]]) + encoder = Encoder(X.values[:,:], header=X.columns[:], mode="none", target=y[y.columns[0]]) headers = encoder.headers X = pd.DataFrame(encoder.encode(X.values[:,:]), columns=encoder.headers) @@ -314,13 +313,13 @@ def __python_train__(self, X, y): else: decoded_leaves = [] for leaf in leaves_c: - decoded_leaf = tuple((dic[j] if j > 0 else -dic[-j]) for j 
in leaf) decoded_leaves.append(decoded_leaf) - source = self.__translate__(dict(zip(decoded_leaves, pred_c))) + source = self.__translate__(dict(zip(decoded_leaves, pred_c)), headers) self.tree = TreeClassifier(source, encoder=encoder) self.tree.__initialize_training_loss__(X, y) - def __translate__(self, leaves): + def __translate__(self, leaves, headers): """ Converts the leaves of OSDT into a TreeClassifier-compatible object """ @@ -334,8 +333,8 @@ def __translate__(self, leaves): else: features = {} for leaf in leaves.keys(): - if not leaf in features: - for e in leaf: + for e in leaf: + if not abs(e) in features: features[abs(e)] = 1 else: features[abs(e)] += 1 @@ -352,13 +351,15 @@ def __translate__(self, leaves): positive_leaves[tuple(s for s in leaf if s != split)] = prediction else: negative_leaves[tuple(s for s in leaf if s != -split)] = prediction + if split != None: + split = split-1 return { "feature": split, - "name": "feature_" + str(split), + "name": headers[split], "reference": 1, "relation": "==", - "true": self.__translate__(positive_leaves), - "false": self.__translate__(negative_leaves), + "true": self.__translate__(positive_leaves, headers), + "false": self.__translate__(negative_leaves, headers), } def __translate_cart__(self, tree, id=0, depth=-1):