forked from dwf/glmnet-python
Basic functionality for ElasticNet and LogNet.
Showing 6 changed files with 277 additions and 289 deletions.
@@ -0,0 +1 @@
+f2py -c --fcompiler=gnu95 --f77flags='-fdefault-real-8' --f90flags='-fdefault-real-8' glmnet.pyf glmnet.f
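The new one-line build script above compiles the Fortran sources into a Python extension with f2py, forcing 8-byte default reals so the Fortran routines match numpy's float64. A quick smoke test after building, assuming the command is run from the source directory and succeeds, is to import the generated `_glmnet` module; f2py attaches the wrapped Fortran call signature to each routine's docstring:

```python
# Hypothetical post-build check, not part of this commit: f2py should have
# produced a _glmnet extension module in the working directory.
import _glmnet

# f2py-generated wrappers carry their Fortran signature in __doc__.
print(_glmnet.elnet.__doc__)
```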
@@ -1,48 +1,76 @@
 import numpy as np
-from glmnet import elastic_net
-
-class ElasticNet(object):
-    """ElasticNet based on GLMNET"""
-    def __init__(self, alpha, rho=0.2):
-        super(ElasticNet, self).__init__()
-        self.alpha = alpha
-        self.rho = rho
-        self.coef_ = None
-        self.rsquared_ = None
-
-    def fit(self, X, y):
-        n_lambdas, intercept_, coef_, _, _, rsquared_, lambdas, _, jerr \
-            = elastic_net(X, y, self.rho, lambdas=[self.alpha])
-        # elastic_net will fire exception instead
-        # assert jerr == 0
-        self.coef_ = coef_
-        self.intercept_ = intercept_
-        self.rsquared_ = rsquared_
-        return self
+import matplotlib
+import matplotlib.pyplot as plt
+import _glmnet
+from glmnet import GlmNet
+
+class ElasticNet(GlmNet):
+
+    def _fit(self, X, y):
+        # Predictors and response.
+        X = np.asanyarray(X)
+        y = np.asanyarray(y)
+
+        # Make a copy if we are not able to overwrite X with its standardized
+        # version. Note that if X is not fortran contiguous, then it will be
+        # copied anyway.
+        if np.isfortran(X) and not self.overwrite_pred_ok:
+            X = X.copy(order='F')
+
+        # The target array will usually be overwritten with its standardized
+        # version; if this is not ok, we should copy.
+        if not self.overwrite_targ_ok:
+            y = y.copy()
+
+        # Setup is complete, call the wrapper.
+        (self._out_n_lambdas,
+         self.intercepts,
+         self._comp_coef,
+         self._p_comp_coef,
+         self._n_comp_coef,
+         self.r_sqs,
+         self.out_lambdas,
+         self._n_passes,
+         self._error_flag) = _glmnet.elnet(self.alpha,
+                                           X,
+                                           y,
+                                           self.weights,
+                                           self.excl_preds,
+                                           self.rel_penalties,
+                                           self.max_vars_all,
+                                           self.frac_lg_lambda,
+                                           self.lambdas,
+                                           self.thresh,
+                                           nlam=self.n_lambdas)
+        self._indicies = np.trim_zeros(self._p_comp_coef, 'b') - 1
+
+        # Check for errors, documented in glmnet.f.
+        if self._error_flag != 0:
+            if self._error_flag == 10000:
+                raise ValueError('cannot have max(vp) < 0.0')
+            elif self._error_flag == 7777:
+                raise ValueError('all used predictors have 0 variance')
+            elif self._error_flag < 7777:
+                raise MemoryError('elnet() returned error code %d'
+                                  % self._error_flag)
+            else:
+                raise Exception('unknown error: %d' % self._error_flag)
+
+    @property
+    def coefficients(self):
+        return self._comp_coef[:np.max(self._n_comp_coef),
+                               :self._out_n_lambdas]
 
     def predict(self, X):
-        return np.dot(X, self.coef_) + self.intercept_
-
-    def __str__(self):
-        n_non_zeros = (np.abs(self.coef_) != 0).sum()
-        return ("%s with %d non-zero coefficients (%.2f%%)\n" + \
-                " * Intercept = %.7f, Lambda = %.7f\n" + \
-                " * Training r^2: %.4f") % \
-               (self.__class__.__name__, n_non_zeros,
-                n_non_zeros / float(len(self.coef_)) * 100,
-                self.intercept_[0], self.alpha, self.rsquared_[0])
-
-
-def elastic_net_path(X, y, rho, **kwargs):
-    """return full path for ElasticNet"""
-    n_lambdas, intercepts, coefs, _, _, _, lambdas, _, jerr \
-        = elastic_net(X, y, rho, **kwargs)
-    return lambdas, coefs, intercepts
-
-def Lasso(alpha):
-    """Lasso based on GLMNET"""
-    return ElasticNet(alpha, rho=1.0)
-
-def lasso_path(X, y, **kwargs):
-    """return full path for Lasso"""
-    return elastic_net_path(X, y, rho=1.0, **kwargs)
+        return self.intercepts + np.dot(X[:, self._indicies],
+                                        self.coefficients)
+
+    def _plot_path(self):
+        plt.clf()
+        xvals = np.log(self.out_lambdas[1:self._out_n_lambdas])
+        for coef_path in self.coefficients:
+            plt.plot(xvals, coef_path[1:])
+        plt.show()
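The bookkeeping around `self._indicies` deserves a note: glmnet.f returns coefficients in a compressed form, reporting the active predictors as 1-based Fortran indices padded with trailing zeros. A minimal numpy sketch of the conversion done in `_fit`, using made-up values:

```python
import numpy as np

# Made-up stand-in for the index array elnet returns (1-based, zero-padded).
p_comp_coef = np.array([3, 1, 7, 0, 0])

# Trim the trailing zero padding ('b' = back) and shift to 0-based indexing,
# exactly as _fit does to build self._indicies.
indices = np.trim_zeros(p_comp_coef, 'b') - 1
print(indices)  # [2 0 6]
```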
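Note also that the new `predict` is vectorized over the whole regularization path: `coefficients` has one column per fitted lambda, so `predict` returns a matrix of fitted values with one column per lambda rather than a single vector. A shape check under assumed dimensions (the names here are illustrative, not part of the commit):

```python
import numpy as np

n_obs, n_active, n_lambdas = 5, 3, 10
X_active = np.random.randn(n_obs, n_active)          # plays the role of X[:, self._indicies]
coefficients = np.random.randn(n_active, n_lambdas)  # one coefficient path per column
intercepts = np.random.randn(n_lambdas)              # one intercept per lambda

# Broadcasting adds the per-lambda intercepts across each row of predictions.
predictions = intercepts + np.dot(X_active, coefficients)
assert predictions.shape == (n_obs, n_lambdas)
```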