-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
logistic regression without weight_regularization
- Loading branch information
U-Robert-PC\Robert
authored and
U-Robert-PC\Robert
committed
Oct 2, 2012
0 parents
commit 19417d5
Showing
15 changed files
with
280 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
%% | ||
clear all; | ||
close all; | ||
|
||
%% load data | ||
load('mnist_test.mat'); | ||
load('mnist_train.mat'); | ||
[n,m] = size(inputs_train); | ||
|
||
%learning rate | ||
parameters.learning_rate = 0.01 | ||
%weight regularization parameter | ||
parameters.weight_regularization = 0.5 | ||
%number of iterations | ||
parameters.num_iterations = 200 | ||
%logistics regression weights | ||
% weights = zeros(n+1, 1); | ||
weights = rand(n+1, 1); | ||
|
||
%% verify that your logistic function produces the right gradient, diff should be very close to 0 | ||
%this creates small random data with 20 examples and 10 dimensions and checks the gradient on | ||
%that data | ||
nexamples = 20; | ||
ndimensions = 10; | ||
diff = checkgrad('logistic', ... | ||
randn((ndimensions + 1), 1), ... % weights | ||
0.001,... % perturbation | ||
randn(nexamples, ndimensions), ... % data | ||
rand(nexamples, 1), ... % targets | ||
parameters) | ||
|
||
% begin learning with gradient descent | ||
for t = 1:parameters.num_iterations | ||
% find the negative log likelihood and derivative w.r.t. weights | ||
[f, df, frac_correct_train] = logistic(weights, inputs_train', target_train(1,:)', parameters); | ||
|
||
% find the fraction of correctly classified validation examples | ||
[temp, temp2, frac_correct_valid] = logistic(weights, inputs_test', target_test(1,:)', parameters); | ||
|
||
% | ||
if isnan(f) || isinf(f) | ||
error('nan/inf error'); | ||
end | ||
|
||
% update parameters | ||
weights = weights - parameters.learning_rate .* df; | ||
|
||
% print some stats | ||
fprintf(1, 'ITERATION:%4i LOGL:%4.2f TRAIN FRAC:%2.2f VALID FRAC:%2.2f\n',... | ||
t, f, frac_correct_train*100, frac_correct_valid*100); | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
function d = check(f, X, e, P1, P2, P3, P4, P5); | ||
|
||
% checkgrad checks the derivatives in a function, by comparing them to finite | ||
% differences approximations. The partial derivatives and the approximation | ||
% are printed and the norm of the diffrence divided by the norm of the sum is | ||
% returned as an indication of accuracy. | ||
% | ||
% usage: checkgrad('f', X, e, P1, P2, ...) | ||
% | ||
% where X is the argument and e is the small perturbation used for the finite | ||
% differences. and the P1, P2, ... are optional additional parameters which | ||
% get passed to f. The function f should be of the type | ||
% | ||
% [fX, dfX] = f(X, P1, P2, ...) | ||
% | ||
% where fX is the function value and dfX is a vector of partial derivatives. | ||
% | ||
% Carl Edward Rasmussen, 2001-08-01. | ||
|
||
argstr = [f, '(X']; % assemble function call strings | ||
argstrd = [f, '(X+dx']; | ||
for i = 1:(nargin - 3) | ||
argstr = [argstr, ',P', int2str(i)]; | ||
argstrd = [argstrd, ',P', int2str(i)]; | ||
end | ||
argstr = [argstr, ')']; | ||
argstrd = [argstrd, ')']; | ||
|
||
[y dy] = eval(argstr); % get the partial derivatives dy | ||
|
||
dh = zeros(length(X),1) ; | ||
for j = 1:length(X) | ||
dx = zeros(length(X),1); | ||
dx(j) = dx(j) + e; % perturb a single dimension | ||
y2 = eval(argstrd); | ||
dx = -dx ; | ||
y1 = eval(argstrd); | ||
dh(j) = (y2 - y1)/(2*e); | ||
end | ||
|
||
disp([dy dh]) % print the two vectors | ||
d = norm(dh-dy)/norm(dh+dy); % return norm of diff divided by norm of sum | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
function [label_test] = knn(k, data_train, label_train, data_test) | ||
|
||
error(nargchk(4,4,nargin)); | ||
|
||
D = size(label_train, 1); | ||
|
||
dist = l2_distance(data_train, data_test); | ||
[sorted_dist, nearest] = sort(dist); | ||
nearest = nearest(1:k,:); | ||
|
||
label_test = zeros(D, size(data_test, 2), k); | ||
for i=1:k | ||
label_test(:,:,i) = label_train(:, nearest(i, :)); | ||
end | ||
|
||
label_test = mean(label_test,3); | ||
label_test = label_test == repmat(max(label_test, [], 1), D, 1); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
function d = L2_distance(a,b,df) | ||
% L2_DISTANCE - computes Euclidean distance matrix | ||
% | ||
% E = L2_distance(A,B) | ||
% | ||
% A - (DxM) matrix | ||
% B - (DxN) matrix | ||
% df = 1, force diagonals to be zero; 0 (default), do not force | ||
% | ||
% Returns: | ||
% E - (MxN) Euclidean distances between vectors in A and B | ||
% | ||
% | ||
% Description : | ||
% This fully vectorized (VERY FAST!) m-file computes the | ||
% Euclidean distance between two vectors by: | ||
% | ||
% ||A-B|| = sqrt ( ||A||^2 + ||B||^2 - 2*A.B ) | ||
% | ||
% Example : | ||
% A = rand(400,100); B = rand(400,200); | ||
% d = distance(A,B); | ||
|
||
% Author : Roland Bunschoten | ||
% University of Amsterdam | ||
% Intelligent Autonomous Systems (IAS) group | ||
% Kruislaan 403 1098 SJ Amsterdam | ||
% tel.(+31)20-5257524 | ||
% [email protected] | ||
% Last Rev : Wed Oct 20 08:58:08 MET DST 1999 | ||
% Tested : PC Matlab v5.2 and Solaris Matlab v5.3 | ||
|
||
% Copyright notice: You are free to modify, extend and distribute | ||
% this code granted that the author of the original code is | ||
% mentioned as the original author of the code. | ||
|
||
% Fixed by JBT (3/18/00) to work for 1-dimensional vectors | ||
% and to warn for imaginary numbers. Also ensures that | ||
% output is all real, and allows the option of forcing diagonals to | ||
% be zero. | ||
|
||
if (nargin < 2) | ||
error('Not enough input arguments'); | ||
end | ||
|
||
if (nargin < 3) | ||
df = 0; % by default, do not force 0 on the diagonal | ||
end | ||
|
||
if (size(a,1) ~= size(b,1)) | ||
error('A and B should be of same dimensionality'); | ||
end | ||
|
||
if ~(isreal(a)*isreal(b)) | ||
disp('Warning: running distance.m with imaginary numbers. Results may be off.'); | ||
end | ||
|
||
if (size(a,1) == 1) | ||
a = [a; zeros(1,size(a,2))]; | ||
b = [b; zeros(1,size(b,2))]; | ||
end | ||
|
||
aa=sum(a.*a); bb=sum(b.*b); ab=a'*b; | ||
d = sqrt(repmat(aa',[1 size(bb,2)]) + repmat(bb,[size(aa,2) 1]) - 2*ab); | ||
|
||
% make sure result is all real | ||
d = real(d); | ||
|
||
% force 0 on the diagonal? | ||
if (df==1) | ||
d = d.*(1-eye(size(d))); | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
%%%% | ||
%Calculate log likelihood and derivatives with respect to weights | ||
% Inputs: | ||
% weights - (M+1) by 1 vector of weights, last element corresponds to bias (intercepts) | ||
% data - N by M data matrix where each row corresponds to one data point | ||
% targets - N by 1 vector of targets class probabilites | ||
% parameters - structure with additional parameters | ||
% Outputs: | ||
% f - cross entropy | ||
% df - (M+1) by 1 vector of derivatives | ||
% frac_correct - fraction of correctly classified examples | ||
%%%% | ||
function [f, df, frac_correct] = logistic(weights, data, targets, parameters) | ||
|
||
%get the dimention of the output | ||
[n,m] = size(targets); | ||
|
||
%append a column of zero for w0 | ||
data = [ones(size(data,1),1), data]; | ||
|
||
%compute the prob of our classification | ||
p = sigmoid(data*weights); | ||
|
||
%compute the cross-entropy for where targets only takes value of 0 or 1 | ||
f = -sum(targets .* log(p) + (1 - targets) .* log(1 - p)); | ||
|
||
%compute the derivative | ||
for k = 1:size(weights) | ||
df(k,1) = sum((p-targets) .* data(:, k)); | ||
end | ||
|
||
%compute the correctly predicted output | ||
p = (p >= 0.5); | ||
frac_correct = sum(targets == p) / size(targets, 1); | ||
|
||
end |
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
function plotDigits(digitPats) | ||
|
||
% After calling this function, re-size your figure window so that | ||
% each pixel is approximately square. | ||
|
||
colormap('gray'); | ||
j=0; | ||
examplesPerDigit = size(digitPats,2)/2; | ||
clf; | ||
subplot(2,examplesPerDigit,1); | ||
for dig = 1:2 | ||
for pat = 1:examplesPerDigit | ||
j = j+1; | ||
axis off, subplot(2,examplesPerDigit,j); | ||
imagesc(reshape(digitPats(:,(dig-1)*examplesPerDigit+pat),28,28)'); | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
function plotKnn(data_train, label_train, data_test, label_test) | ||
|
||
if (nargin < 4) | ||
error('Not enough input arguments'); | ||
end | ||
|
||
hold on; | ||
|
||
for k=0:29 | ||
result = knn(k+1, data_train, label_train, data_test); | ||
total = sum( abs( label_test(1,:) - result(1,:))); | ||
percent_correct = ( size(label_test,2) - total )/size(label_test,2) ; | ||
plot(k+1, percent_correct, 'b*'); | ||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
function plotSpiral(inputs,target,flag) | ||
|
||
if nargin <3 | ||
flag = 1; | ||
end | ||
|
||
hold on; | ||
|
||
if flag ==1 | ||
for i = 1:size(inputs,2) | ||
if target(1,i) == 0 | ||
plot(inputs(1,i),inputs(2,i),'r*'); | ||
else | ||
plot(inputs(1,i),inputs(2,i),'bo'); | ||
end | ||
end | ||
elseif flag == 0 | ||
for i = 1:size(inputs,2) | ||
plot(inputs(1,i),inputs(2,i),'y.'); | ||
end | ||
else | ||
for i = 1:size(inputs,2) | ||
plot(inputs(1,i),inputs(2,i),'g.'); | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
function [output] = sigmoid(totalIn) | ||
output = 1.0 ./ (ones(size(totalIn)) + exp(-totalIn)); | ||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.