たそらぼ

日頃思ったこととかメモとか。

lassoを実装した話

だいぶ前ですが、lassoに関心があり、書籍を参考に実装しました。
参考にした書籍は以下です。

スパース推定法による統計モデリング (統計学One Point)

スパース推定法による統計モデリング (統計学One Point)

lassoとは?:
lasso(least absolute shrinkage and selection operation)は、線形回帰モデル
{ \displaystyle 
 \boldsymbol{y} = \boldsymbol{X} \boldsymbol{\beta}
}
みたいなのを考えた際に、L1罰則項をつけることで、データから自動的に変数xを選択してくれるアルゴリズムです。
実行すると分かりますが、不要とされたxが綺麗に0になるため、スパース推定の代表的な方法となっています。

具体的には、正則化問題
{ \displaystyle S_\lambda (\boldsymbol{\beta}) = \| \boldsymbol{y}- \boldsymbol{X} \boldsymbol{\beta} \|^2_2 + \lambda \|\boldsymbol{\beta} \|_1 }
(ここで{ \displaystyle \|  \cdot \|_1} はL1ノルム、{ \displaystyle \| \cdot \|_2} はL2ノルム)
について、{ \displaystyle S_\lambda (\boldsymbol{\beta})}が最小になるような{\boldsymbol{\beta})}を最適化手法により探します。

今回はこれを一番簡単な座標降下法で解きました。
座標降下法を始めとする最適化手法は、以下の『Distributed Optimization and Statistical Learning via the Alternating Direction Method of Multipliers』に詳しくまとまっています。
Distributed Optimization and Statistical Learning via the Alternating Direction Method of Multipliers

def Lasso(X, y):
    #パラメータ 
    param = 1. 
    threshould = 0.001
    maxcyc = 100

    #betaの初期値、この場合は[1,1,…,1]で良いようです。
    betatmp = np.ones( X.shape[1] )

    #座標降下法の更新式に従って、betaを更新する。
    for ii in range( maxcyc ):
        beta = np.array( betatmp )
        for jj in range( X.shape[1] ):
            beta[jj] = estBeta_CD(X, y, beta, jj, param ) 
            if abs( sum( beta - betatmp ) ) < threshould:
                print(beta)
                return beta
            else:
                betatmp = np.array( beta )

    #betaが収束しなければエラーで終了する。
    print('ERROR!! ITERATION NOT COMPLETED!!')
    return beta

#新しいbetaの計算
def estBeta_CD(X, y, betatmp, num, param):
    """
    estimate beta by coordinate descent method
    """
    Xtmp = np.delete(X,num,1)
    xj = X.T[num]
    betatmp = np.delete(betatmp,num) 
    r = y - np.dot(Xtmp, betatmp)
    return Soft_threshold_operator(xj, r, param)

#軟閾値作用素
def Soft_threshold_operator(xj, y, param):
    """return the value of soft thereshould operator"""
    x =  np.dot(xj, y) / len(y)
    if x > 0:
        sign = 1
    elif x == 0:
        sign = 0
    else:
        sign = -1
    return sign * max( abs(x) - param, 0)

疑似コードも載せれれば分かりやすいのですが、著作権的にOKか分からないので、気になる方はぜひ本の方を見て見てください。