Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

ENH add ADMM solver #76

Open
mathurinm opened this issue May 6, 2022 · 0 comments
Open

ENH add ADMM solver #76

mathurinm opened this issue May 6, 2022 · 0 comments

Comments

@mathurinm
Copy link
Collaborator

@josephsalmon here's a basic ADMM solver for the Lasso:

import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt

from scipy import io
from sklearn.preprocessing import MinMaxScaler

from skglm.utils import ST_vec


def primal_lasso(X, y, alpha, w):
    """Evaluate the Lasso primal objective 1/2 ||Xw - y||^2 + alpha ||w||_1.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        Design matrix.

    y : array, shape (n_samples,)
        Target vector.

    alpha : float
        Regularization parameter.

    w : array, shape (n_features)
        Coefficient vector.

    Returns
    -------
    p_obj : float
        The primal objective.
    """
    residual = X @ w - y
    datafit = 0.5 * residual.dot(residual)
    penalty = alpha * np.abs(w).sum()
    return datafit + penalty


def extrapolate(last_K_u, K, extrap_type="LPinf", s=100):
    """Construct an extrapolated point from past iterates.

    Parameters
    ----------
    last_K_u : array, shape (n_features, K)
        Array of past iterates.

    K : int
        Number of past iterates to use for extrapolation.

    extrap_type : str, (`LP` | `LPinf`), optional
        Extrapolation type.

    s : int
        LP steps (used if extra_type=`LP`).

    Returns
    -------
    w : array, shape (n_features,)
        Extrapolated vectors
    """
    q = K - 2

    # Successive differences of the stored iterates.
    diffs = np.diff(last_K_u, 1, axis=-1)
    recent_diffs = diffs[:, 1:]
    newest_diff = diffs[:, -1]
    older_diffs = diffs[:, :-1]

    # Least-squares fit of the newest difference on the older ones.
    gram = older_diffs.T @ older_diffs
    try:
        coefs = np.linalg.solve(gram, older_diffs.T @ newest_diff)
    except np.linalg.LinAlgError:
        # Singular Gram matrix: fall back to the plain last difference.
        return newest_diff

    # Companion-style iteration matrix built from the fitted coefficients.
    C = np.diag(np.ones(q - 1), -1)
    C[:, -1] = coefs

    # Spectral radius decides whether the geometric series converges.
    rho = norm(np.linalg.eigvals(C), ord=np.inf)

    if rho >= 1:
        S = 0 * C
    elif extrap_type == "LPinf":
        S = np.linalg.solve((np.eye(q) - C).T, C.T).T
    else:
        C_pow = np.linalg.matrix_power(C, s)
        S = np.linalg.solve((np.eye(q) - C).T, (C - C_pow).T).T

    return recent_diffs @ S[:, -1]


def admm(X, y, alpha, gamma=2., max_iter=1000, tol=1e-5, check_gap_freq=50, a=0, K=6,
         use_accel=True, verbose=True):
    """Run the Alternating Direction Method of Multipliers (ADMM) for the Lasso.

    Alternates a datafit proximal step (a linear solve via a pre-computed
    Cholesky factorization), a penalty proximal step (soft-thresholding), an
    inertial step, and an optional extrapolation step every ``K + 1`` iterations.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        Design matrix.

    y : array, shape (n_samples,)
        Target vector.

    alpha : float
        Regularization parameter.

    gamma : float
        Augmented Lagrangian parameter.

    max_iter : int
        Maximum number of iterations.

    tol : float
        Tolerance on the norm of successive dual-variable differences.

    check_gap_freq : int
        Frequency (in iterations) at which the objective is computed and
        convergence is checked.

    a : float
        Inertia parameter.

    K : int
        Number of past iterates to compute extrapolated point.

    use_accel : bool
        Use extrapolation.

    verbose : bool
        Verbosity.

    Returns
    -------
    w : array, shape (n_features,)
        Coefficient vector.

    residuals : list of float
        Norm of the successive dual-variable differences, one per iteration.

    iterates : list of arrays
        Primal iterate ``w`` recorded at every iteration.
    """
    n_features = X.shape[1]
    residuals = []
    iterates = []

    # Acceleration variables: rolling buffer of the K most recent u iterates.
    last_K_u = np.zeros((n_features, K))

    # Optimization variables
    w = np.ones(n_features)  # primal iterate
    z = np.ones(n_features)  # auxiliary (split) primal iterate
    psi = np.ones(n_features)  # dual iterate
    u = psi + gamma * w
    u_bar = u

    v = np.zeros_like(u)  # difference between consecutive u iterates

    # Pre-compute the Cholesky factorization of (X^T X / gamma + I) so the
    # datafit proximal step reduces to two triangular solves per iteration.
    XtX_scaled = X.T @ X / gamma
    Xty_scaled = X.T @ y / gamma
    L = np.linalg.cholesky(XtX_scaled + np.eye(n_features))
    U = L.T

    # NOTE: renamed the loop variable from `iter` to avoid shadowing the builtin.
    for it in range(1, max_iter + 1):
        u_prev = u.copy()

        # Proximal step for datafit: solve (X^T X / gamma + I) z = rhs
        z = np.linalg.solve(U, np.linalg.solve(L, Xty_scaled + u_bar / gamma))
        psi = u_bar - gamma * z  # Dual update

        # Proximal step for pen (soft-thresholding)
        w = ST_vec((u_bar - 2 * psi) / gamma, alpha / gamma)
        u = psi + gamma * w
        iterates.append(w)

        # Inertial step
        v = u - u_prev
        u_bar = u + a * v

        last_K_u = np.column_stack((last_K_u[:, 1:], u))

        if use_accel and it % (K + 1) == 0:
            e = extrapolate(last_K_u, K)
            with np.errstate(divide="ignore"):
                # Removes warning for zero division when e is all zeros.
                # Parameter safeguard - cap the step to avoid numerical errors.
                coeff = np.minimum(1., 1e5 / (it**1.1 * norm(e)))
            u = u + coeff * e
            u_bar = u

        res = norm(v)
        residuals.append(res)

        if it % check_gap_freq == 0:
            p_obj = primal_lasso(X, y, alpha, w)
            if verbose:
                print(f"iter {it} :: residual {res:.5f} :: obj {p_obj:.4f}")

            if res < tol:
                break
    return w, residuals, iterates


if __name__ == "__main__":
    # Matrices can be downloaded at:
    # https://github.com/jliang993/A3DMM/tree/master/codes/data
    X = io.loadmat('covtype_sample.mat')["h"]
    y = np.ravel(io.loadmat('covtype_label.mat')["l"])

    # Densify and rescale all features to [-1, 1].
    X = MinMaxScaler(feature_range=(-1, 1)).fit_transform(X.toarray())

    alpha = 1
    w, residuals, iterates = admm(X, y, alpha, tol=1e-6, use_accel=True,
                                  max_iter=50_000, check_gap_freq=10)
    print("#" * 25)
    w_no_accel, residuals_no_acc, iterates_no_acc = admm(
        X, y, alpha, tol=1e-6, use_accel=False, max_iter=50_000,
        check_gap_freq=100)

    # Both runs should reach the same solution.
    np.testing.assert_allclose(w, w_no_accel, rtol=1e-4)

    # Plot the log-distance of each iterate to its run's final solution.
    norms_accel = [np.log(norm(wc - w)) for wc in iterates]
    norms_no_accel = [np.log(norm(wc - w_no_accel)) for wc in iterates_no_acc]
    plt.plot(norms_accel, label="Accelerated")
    plt.plot(norms_no_accel, label="No accel")
    plt.legend()
    plt.title("ADMM - ||x - x^*||")
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant