-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Re-write to fix references to old name, "SMITE"
- Loading branch information
1 parent
0cc4c2e
commit a2ec126
Showing
11 changed files
with
41 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
[run] | ||
source = smite | ||
source = smrt | ||
omit = | ||
*/python?.?/* | ||
*/lib-python/?.?/*.py | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,12 +2,11 @@ | |
# | ||
# Author: Taylor Smith <[email protected]> | ||
# | ||
# Setup the SMITE module | ||
# Setup the SMRT module | ||
|
||
from __future__ import print_function, absolute_import, division | ||
from distutils.command.clean import clean | ||
from setuptools import setup | ||
import numpy as np | ||
import os | ||
import sys | ||
|
||
|
@@ -17,12 +16,12 @@ | |
import builtins | ||
|
||
# Hacky, adopted from sklearn. This sets a global variable | ||
# so smite __init__ can detect if it's being loaded in the setup | ||
# so smrt __init__ can detect if it's being loaded in the setup | ||
# routine, so it won't load submodules that haven't yet been built. | ||
builtins.__SMITE_SETUP__ = True | ||
builtins.__SMRT_SETUP__ = True | ||
|
||
# metadata | ||
DISTNAME = 'smite' | ||
DISTNAME = 'smrt' | ||
DESCRIPTION = 'Handle class imbalance intelligently by using autoencoders ' \ | ||
'to generate synthetic observations of your minority class.' | ||
|
||
|
@@ -31,8 +30,8 @@ | |
LICENSE = 'new BSD' | ||
|
||
# import restricted version | ||
import smite | ||
VERSION = smite.__version__ | ||
import smrt | ||
VERSION = smrt.__version__ | ||
|
||
# get the installation requirements: | ||
with open('requirements.txt') as req: | ||
|
@@ -43,7 +42,7 @@ | |
class CleanCommand(clean): | ||
description = "Remove build artifacts from the source tree" | ||
|
||
# this is mostly in case we ever add a Cython module to SMITE | ||
# this is mostly in case we ever add a Cython module to SMRT | ||
def run(self): | ||
clean.run(self) | ||
# Remove c files if we are not within a sdist package | ||
|
@@ -97,7 +96,7 @@ def run(self): | |
'Operating System :: MacOS', | ||
'Programming Language :: Python :: 2.7' | ||
], | ||
keywords='sklearn scikit-learn sknn scikit-neuralnetwork auto-encoders class-imbalance', | ||
keywords='sklearn scikit-learn tensorflow scikit-neuralnetwork auto-encoders class-imbalance', | ||
packages=[DISTNAME], | ||
install_requires=REQUIREMENTS, | ||
cmdclass=cmdclass) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# | ||
# Author: Taylor Smith <[email protected]> | ||
# | ||
# The SMITE module | ||
# The SMRT module | ||
|
||
import sys | ||
import os | ||
|
@@ -13,18 +13,18 @@ | |
# this var is injected in the setup build to enable | ||
# the retrieval of the version number without actually | ||
# importing the un-built submodules. | ||
__SMITE_SETUP__ | ||
__SMRT_SETUP__ | ||
except NameError: | ||
__SMITE_SETUP__ = False | ||
__SMRT_SETUP__ = False | ||
|
||
if __SMITE_SETUP__: | ||
sys.stderr.write('Partial import of SMITE during the build process.' + os.linesep) | ||
if __SMRT_SETUP__: | ||
sys.stderr.write('Partial import of SMRT during the build process.' + os.linesep) | ||
else: | ||
__all__ = [ | ||
'autoencode.py', | ||
'balance' | ||
] | ||
|
||
# top-level imports | ||
from .balance import smite_balance | ||
from .balance import smrt_balance | ||
from .autoencode import AutoEncoder |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# | ||
# Author: Taylor Smith <[email protected]> | ||
# | ||
# The SMITE balancer | ||
# The SMRT balancer | ||
|
||
from __future__ import division, absolute_import, division | ||
from numpy.random import RandomState | ||
|
@@ -17,39 +17,34 @@ | |
import numpy as np | ||
|
||
__all__ = [ | ||
'smite_balance' | ||
'smrt_balance' | ||
] | ||
|
||
MAX_N_CLASSES = 100 # max unique classes in y | ||
MIN_N_SAMPLES = 2 # min n_samples per class in y | ||
|
||
|
||
def _validate_layers(layers): | ||
if not all(isinstance(x, LayerParameters) for x in layers): | ||
raise ValueError('layers should be a list, tuple or dict of smite.balance.LayerParameters') | ||
|
||
|
||
def _validate_ratios(ratio, name): | ||
if not 0. < ratio <= 1.: | ||
raise ValueError('Expected 0 < %s <= 1, but got %r' | ||
% (name, ratio)) | ||
|
||
|
||
def smite_balance(X, y, return_encoders=False, balance_ratio=0.2, eps=1.0, random_state=None, | ||
parameters=None, learning_rule='sgd', learning_rate=0.01, learning_momentum=0.9, batch_size=1, | ||
n_iter=None, n_stable=10, f_stable=0.001, valid_set=None, valid_size=0.0, normalize=None, regularize=None, | ||
weight_decay=None, dropout_rate=None, loss_type=None, callback=None, debug=False, verbose=None, | ||
**auto_encoder_params): | ||
"""SMITE (Sythetic Minority Interpolation TEchnique) is the younger, more sophisticated cousin to | ||
SMOTE (Synthetic Minority Oversampling TEchnique). Using auto-encoders, SMITE learns the parameters | ||
def smrt_balance(X, y, return_encoders=False, balance_ratio=0.2, eps=1.0, random_state=None, | ||
parameters=None, learning_rule='sgd', learning_rate=0.01, learning_momentum=0.9, batch_size=1, | ||
n_iter=None, n_stable=10, f_stable=0.001, valid_set=None, valid_size=0.0, normalize=None, regularize=None, | ||
weight_decay=None, dropout_rate=None, loss_type=None, callback=None, debug=False, verbose=None, | ||
**auto_encoder_params): | ||
"""SMRT (Synthetic Minority Reconstruction Technique) is the younger, more sophisticated cousin to | ||
SMOTE (Synthetic Minority Oversampling TEchnique). Using auto-encoders, SMRT learns the parameters | ||
that best reconstruct the observations in each minority class, and then generates synthetic observations | ||
until the minority class is represented at a minimum of ``balance_ratio`` * majority_class_size. | ||
SMITE avoids one of SMOTE's greatest risks: In SMOTE, when drawing random observations from whose k-nearest | ||
SMRT avoids one of SMOTE's greatest risks: in SMOTE, when randomly drawing the observations whose k-nearest | ||
neighbors are used for reconstruction, a "border point" (an observation very close to the decision | ||
boundary) may be selected. This could result in the synthetically-generated observations lying | ||
too close to the decision boundary for reliable classification, and could lead to the degraded performance | ||
of an estimator. SMITE avoids this risk, by ranking observations according to their reconstruction MSE, and | ||
of an estimator. SMRT avoids this risk by ranking observations according to their reconstruction MSE and | ||
drawing samples to reconstruct from the lowest-MSE observations (i.e., the most "phenotypical" of a class). | ||
Parameters | ||
|
@@ -233,14 +228,14 @@ def smite_balance(X, y, return_encoders=False, balance_ratio=0.2, eps=1.0, rando | |
y_type = type_of_target(y) | ||
supported_types = {'multiclass', 'binary'} | ||
if y_type not in supported_types: | ||
raise ValueError('SMITE balancer only supports %r, but got %r' % (supported_types, y_type)) | ||
raise ValueError('SMRT balancer only supports %r, but got %r' % (supported_types, y_type)) | ||
|
||
present_classes, counts = np.unique(y, return_counts=True) | ||
n_classes = len(present_classes) | ||
|
||
# ensure <= MAX_N_CLASSES | ||
if n_classes > MAX_N_CLASSES: | ||
raise ValueError('SMITE balancer currently only supports a maximum of %i ' | ||
raise ValueError('SMRT balancer currently only supports a maximum of %i ' | ||
'unique class labels, but %i were identified.' % (MAX_N_CLASSES, n_classes)) | ||
|
||
# check layers: | ||
|
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters