Skip to content

Commit

Permalink
regularized diagonal LDA #204
Browse files Browse the repository at this point in the history
  • Loading branch information
szcf-weiya committed Aug 18, 2019
1 parent ddebe6a commit 4bedd76
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions code/LDA/diagonalLDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,44 @@ function DiagLDA(X::Array{Float64, 2}, y::Array{Int})
return δ
end

# regularized version of the diagonal-covariance form of LDA
function RegDiagLDA(X::Array{Float64, 2}, y::Array{Int, 1}, Δ::Float64)
# number of cells & number of observations
p, N = size(X)
# number of class
K = length(unique(y))
# array of discriminant functions
δ = Function[]
# pooled standard deviation
s = std(X, dims = 2)[:] # convert Nx1 matrix to N vector
s0 = median(s) # typically choice
# overall mean
x_bar = mean(X, dims = 2)[:]
# for each class
for k = 1:K
# obs in class k
Xk = X[:, convert(Array{Bool, 2}, y'.==k)[:]]
# mean
xk_bar = mean(Xk, dims = 2)[:]
# shrinkage
mk = sqrt( 1 / sum(y.==k) - 1 / N )
dk = ( xk_bar - x_bar ) ./ (mk * (s .+ s0))
dk_prime = soft_threshold(dk, Δ)
xk_bar_prime = x_bar + mk*(s .+ s0) .* dk_prime
# discriminant function
push!(δ, x-> -sum( ( (x-xk_bar_prime)./s ).^2 ) ) # equal π_k
end
return δ
end

function soft_threshold(x::Float64, Δ::Float64)
return sign(x) * max(0, abs(x) - Δ)
end

function soft_threshold(x::Array{Float64, 1}, Δ::Float64)
return [soft_threshold(xi, Δ) for xi in x]
end

# classify a single observation
function classify(x::Array{Float64, 1}, δ::Array{Function, 1})
scores = [δ[k](x) for k in 1:length(δ)]
Expand All @@ -57,6 +95,7 @@ end
cl = classify(xtrain, δ)
cltest = classify(xtest, δ)


using FreqTables
# train results
freqtable(cl, ytrain[1, :])
Expand All @@ -68,3 +107,22 @@ freqtable(cl, ytrain[1, :])

# test results
freqtable(cltest, ytest[1,:])






δ2 = RegDiagLDA(xtrain, ytrain[:], 2.0)
cl2 = classify(xtrain, δ2)
cltest2 = classify(xtest, δ2)
# RegDiagLDA
freqtable(cl2, ytrain[:])







freqtable(cltest2, ytest[:])

0 comments on commit 4bedd76

Please sign in to comment.