Merge pull request scikit-learn#2426 from larsmans/sgd-improvements
[MRG] SGD Cython improvements
larsmans committed Sep 21, 2013
2 parents cb7b423 + b8e1a6c commit aa2d045
Showing 9 changed files with 3,319 additions and 2,968 deletions.
2,031 changes: 940 additions & 1,091 deletions sklearn/linear_model/sgd_fast.c

Large diffs are not rendered by default.

53 changes: 27 additions & 26 deletions sklearn/linear_model/sgd_fast.pyx
@@ -11,20 +11,18 @@ import numpy as np
 import sys
 from time import time
 
-cimport cython
 from libc.math cimport exp, log, sqrt, pow, fabs
 cimport numpy as np
+cimport cython
+cdef extern from "numpy/npy_math.h":
+    bint isfinite "npy_isfinite"(double) nogil
 
 from sklearn.utils.weight_vector cimport WeightVector
 from sklearn.utils.seq_dataset cimport SequentialDataset
 
 np.import_array()
 
-
-ctypedef np.float64_t DOUBLE
-ctypedef np.int32_t INTEGER
-
 # Penalty constants
 DEF NO_PENALTY = 0
 DEF L1 = 1
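The new cdef extern block binds NumPy's C-level npy_isfinite under the name isfinite, so later helpers can test values without holding the GIL. Its semantics match this pure-Python sketch (illustrative only, not part of the commit; the name isfinite_like is made up):

    import math

    def isfinite_like(x):
        # npy_isfinite(x) is true iff x is neither NaN nor +/- infinity,
        # i.e. the same predicate as Python's math.isfinite.
        return not (math.isnan(x) or math.isinf(x))

    assert isfinite_like(1.0)
    assert not isfinite_like(float("nan"))
    assert not isfinite_like(float("-inf"))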
@@ -323,7 +321,7 @@ cdef class SquaredEpsilonInsensitive(Regression):
         return SquaredEpsilonInsensitive, (self.epsilon,)
 
 
-def plain_sgd(np.ndarray[DOUBLE, ndim=1, mode='c'] weights,
+def plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
               double intercept,
               LossFunction loss,
               int penalty_type,
@@ -401,29 +399,29 @@ def plain_sgd(np.ndarray[DOUBLE, ndim=1, mode='c'] weights,
 
     cdef WeightVector w = WeightVector(weights)
 
-    cdef DOUBLE * x_data_ptr = NULL
-    cdef INTEGER * x_ind_ptr = NULL
+    cdef double *x_data_ptr = NULL
+    cdef int *x_ind_ptr = NULL
 
     # helper variable
     cdef int xnnz
     cdef double eta = 0.0
     cdef double p = 0.0
     cdef double update = 0.0
     cdef double sumloss = 0.0
-    cdef DOUBLE y = 0.0
-    cdef DOUBLE sample_weight
+    cdef double y = 0.0
+    cdef double sample_weight
     cdef double class_weight = 1.0
     cdef unsigned int count = 0
     cdef unsigned int epoch = 0
     cdef unsigned int i = 0
     cdef int is_hinge = isinstance(loss, Hinge)
 
     # q vector is only used for L1 regularization
-    cdef np.ndarray[DOUBLE, ndim = 1, mode = "c"] q = None
-    cdef DOUBLE * q_data_ptr = NULL
+    cdef np.ndarray[double, ndim = 1, mode = "c"] q = None
+    cdef double * q_data_ptr = NULL
     if penalty_type == L1 or penalty_type == ELASTICNET:
         q = np.zeros((n_features,), dtype=np.float64, order="c")
-        q_data_ptr = <DOUBLE * > q.data
+        q_data_ptr = <double * > q.data
     cdef double u = 0.0
 
     if penalty_type == L2:
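The removed DOUBLE/INTEGER typedefs were aliases for np.float64_t/np.int32_t; this hunk (and the signature change above) uses the plain C types double and int instead, which is what the buffers handed out by SequentialDataset contain. A quick NumPy sketch of the dtype and layout assumptions plain_sgd relies on (illustrative only; the array names are made up):

    import numpy as np

    weights = np.zeros(5, dtype=np.float64, order="C")   # backed by a C double buffer
    indices = np.array([0, 2, 4], dtype=np.int32)        # backed by a C int32 buffer

    assert weights.dtype.itemsize == 8      # sizeof(double)
    assert indices.dtype.itemsize == 4      # sizeof(int) on supported platforms
    assert weights.flags["C_CONTIGUOUS"]    # required by mode='c' in the signature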
@@ -502,23 +500,25 @@ def plain_sgd(np.ndarray[DOUBLE, ndim=1, mode='c'] weights,
         print("Total training time: %.2f seconds." % (time() - t_start))
 
     # floating-point under-/overflow check.
-    if np.any(np.isinf(weights)) or np.any(np.isnan(weights)) \
-       or np.isnan(intercept) or np.isinf(intercept):
+    if (not isfinite(intercept)
+            or any_nonfinite(<double *>weights.data, n_features)):
         raise ValueError("floating-point under-/overflow occurred.")
 
     w.reset_wscale()
 
     return weights, intercept
 
 
-cdef inline double max(double a, double b):
-    return a if a >= b else b
+cdef bint any_nonfinite(double *w, int n):
+    cdef int i
+    for i in range(n):
+        if not isfinite(w[i]):
+            return True
+    return 0
 
-cdef inline double min(double a, double b):
-    return a if a <= b else b
 
-cdef double sqnorm(DOUBLE * x_data_ptr, INTEGER * x_ind_ptr, int xnnz):
+cdef double sqnorm(double * x_data_ptr, int * x_ind_ptr, int xnnz) nogil:
     cdef double x_norm = 0.0
     cdef int j
     cdef double z
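The rewritten overflow check replaces four temporary-array-producing NumPy calls with a single C-level scan that can bail out at the first non-finite weight. The before/after behavior, as a pure-Python sketch (function names are made up; weights is any float64 array):

    import numpy as np

    def old_check(weights, intercept):
        # builds boolean temporaries and always scans the whole array
        return (np.any(np.isinf(weights)) or np.any(np.isnan(weights))
                or np.isnan(intercept) or np.isinf(intercept))

    def new_check(weights, intercept):
        # mirrors isfinite(intercept) + any_nonfinite(): one pass, early exit
        if not np.isfinite(intercept):
            return True
        return any(not np.isfinite(v) for v in weights)

    w = np.array([0.0, 1.0, np.inf])
    assert old_check(w, 0.0) and new_check(w, 0.0)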
@@ -527,8 +527,9 @@ cdef double sqnorm(DOUBLE * x_data_ptr, INTEGER * x_ind_ptr, int xnnz):
         x_norm += z * z
     return x_norm
 
-cdef void l1penalty(WeightVector w, DOUBLE * q_data_ptr,
-                    INTEGER * x_ind_ptr, int xnnz, double u):
+
+cdef void l1penalty(WeightVector w, double * q_data_ptr,
+                    int *x_ind_ptr, int xnnz, double u) nogil:
     """Apply the L1 penalty to each updated feature
 
     This implements the truncated gradient approach by
@@ -538,16 +539,16 @@ cdef void l1penalty(WeightVector w, DOUBLE * q_data_ptr,
     cdef int j = 0
     cdef int idx = 0
     cdef double wscale = w.wscale
-    cdef double * w_data_ptr = w.w_data_ptr
+    cdef double *w_data_ptr = w.w_data_ptr
     for j in range(xnnz):
         idx = x_ind_ptr[j]
         z = w_data_ptr[idx]
-        if (wscale * w_data_ptr[idx]) > 0.0:
+        if wscale * w_data_ptr[idx] > 0.0:
             w_data_ptr[idx] = max(
                 0.0, w_data_ptr[idx] - ((u + q_data_ptr[idx]) / wscale))
 
-        elif (wscale * w_data_ptr[idx]) < 0.0:
+        elif wscale * w_data_ptr[idx] < 0.0:
             w_data_ptr[idx] = min(
                 0.0, w_data_ptr[idx] + ((u - q_data_ptr[idx]) / wscale))
 
-        q_data_ptr[idx] += (wscale * (w_data_ptr[idx] - z))
+        q_data_ptr[idx] += wscale * (w_data_ptr[idx] - z)
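l1penalty implements the cumulative ("truncated gradient") L1 scheme the docstring refers to: u is the total L1 penalty accumulated so far, q[idx] records how much penalty feature idx has actually absorbed, and each update clips the weight at zero rather than letting the penalty flip its sign. A pure-Python sketch under the simplifying assumption wscale == 1 (function and variable names are made up):

    def l1penalty_sketch(w, q, active_indices, u):
        # w: weights; q: cumulative penalty already applied per feature;
        # u: total L1 penalty accumulated over all updates so far.
        for idx in active_indices:
            z = w[idx]
            if w[idx] > 0.0:
                w[idx] = max(0.0, w[idx] - (u + q[idx]))
            elif w[idx] < 0.0:
                w[idx] = min(0.0, w[idx] + (u - q[idx]))
            q[idx] += w[idx] - z   # record the penalty actually applied

    w, q = [0.5, -0.2, 0.0], [0.0, 0.0, 0.0]
    l1penalty_sketch(w, q, [0, 1], u=0.25)
    assert w == [0.25, 0.0, 0.0]   # positive weight shrunk, negative one clipped at 0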