Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
gykovacs committed Dec 19, 2023
1 parent a4623b0 commit 60ff996
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
2 changes: 1 addition & 1 deletion smote_variants/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
@author: gykovacs
"""

__version__= '1.0.0'
__version__= '1.0.1'
20 changes: 14 additions & 6 deletions smote_variants/undersampling/_oversampling_driven_undersampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@ class OversamplingDrivenUndersampling(UnderSampling):
The oversampling driven undersampling
"""

def __init__(self, oversampler_specification, random_state=None):
def __init__(self, oversampler_specification, mode="random", random_state=None):
    """
    Set up the oversampling driven undersampler.

    Args:
        oversampler_specification (tuple): the specification of the
            oversampler, handed to ``instantiate_obj`` to build the
            oversampler object
        mode (str): 'random'/'farthest' - the mode of sample removal
        random_state (None|int|np.random.RandomState): the random seed
            or state to be used
    """
    super().__init__(random_state=random_state)

    # Plain configuration first, then the object built from the specification.
    self.mode = mode
    self.oversampler = instantiate_obj(oversampler_specification)

@classmethod
def parameter_combinations(cls, raw=False):
Expand All @@ -37,7 +39,8 @@ def parameter_combinations(cls, raw=False):
("smote_variants", "SMOTE", {"random_state": 5}),
("smote_variants", "ADASYN", {"random_state": 5}),
("smote_variants", "Borderline_SMOTE1", {"random_state": 5}),
]
],
"mode": ["random", "farthest"],
}

return cls.generate_parameter_combinations(parameter_combinations, raw)
Expand All @@ -63,11 +66,16 @@ def sample(self, X, y):
dists = X_new - X_maj[:, None]

dists = np.mean(np.sqrt(np.sum((dists) ** 2, axis=2)), axis=1)
dists = 1.0 / dists
p = dists / np.sum(dists)

mask = self.random_state.choice(np.arange(n_maj), n_min, p=p, replace=False)
X_maj = X_maj[mask]
# `dists` holds one value per majority sample: its mean Euclidean distance
# to the rows of X_new (computed just above).
if self.mode == "random":
    # Clamp near-zero distances so the reciprocal stays finite.
    inv_dists = 1.0 / np.where(dists < 1e-8, 1e-8, dists)
    # The sampling weights are the normalised *inverse* distances. Dividing
    # the raw distances by the sum of the inverses (as before) does not
    # yield a probability vector and makes `choice` raise
    # "probabilities do not sum to 1".
    p = inv_dists / np.sum(inv_dists)

    # Draw n_min distinct majority indices, favouring small mean distance.
    mask = self.random_state.choice(np.arange(n_maj), n_min, p=p, replace=False)
    X_maj = X_maj[mask]
elif self.mode == "farthest":
    # argsort is ascending, so the first n_min indices keep the majority
    # samples with the smallest mean distance, i.e. the farthest are removed.
    sorting = np.argsort(dists)
    X_maj = X_maj[sorting[:n_min]]
else:
    # Without this guard X_maj would stay unreduced while the labels built
    # below assume exactly n_min majority rows, returning mismatched X and y.
    raise ValueError(
        f"unknown mode {self.mode!r}: expected 'random' or 'farthest'"
    )

X_res = np.vstack([X_maj, X_min]) # pylint: disable=invalid-name
y_res = np.hstack([np.repeat(0, n_min), np.repeat(1, n_min)])
Expand Down

0 comments on commit 60ff996

Please sign in to comment.