Moving course1 to course1 subdir.
This commit is contained in:
192
machine_learning/course1/mlclass-ex6-008/mlclass-ex6/svmTrain.m
Executable file
192
machine_learning/course1/mlclass-ex6-008/mlclass-ex6/svmTrain.m
Executable file
@@ -0,0 +1,192 @@
|
||||
function [model] = svmTrain(X, Y, C, kernelFunction, ...
|
||||
tol, max_passes)
|
||||
%SVMTRAIN Trains an SVM classifier using a simplified version of the SMO
|
||||
%algorithm.
|
||||
% [model] = SVMTRAIN(X, Y, C, kernelFunction, tol, max_passes) trains an
|
||||
% SVM classifier and returns trained model. X is the matrix of training
|
||||
% examples. Each row is a training example, and the jth column holds the
|
||||
% jth feature. Y is a column matrix containing 1 for positive examples
|
||||
% and 0 for negative examples. C is the standard SVM regularization
|
||||
% parameter. tol is a tolerance value used for determining equality of
|
||||
% floating point numbers. max_passes controls the number of iterations
|
||||
% over the dataset (without changes to alpha) before the algorithm quits.
|
||||
%
|
||||
% Note: This is a simplified version of the SMO algorithm for training
|
||||
% SVMs. In practice, if you want to train an SVM classifier, we
|
||||
% recommend using an optimized package such as:
|
||||
%
|
||||
% LIBSVM (http://www.csie.ntu.edu.tw/~cjlin/libsvm/)
|
||||
% SVMLight (http://svmlight.joachims.org/)
|
||||
%
|
||||
%
|
||||
|
||||
if ~exist('tol', 'var') || isempty(tol)
|
||||
tol = 1e-3;
|
||||
end
|
||||
|
||||
if ~exist('max_passes', 'var') || isempty(max_passes)
|
||||
max_passes = 5;
|
||||
end
|
||||
|
||||
% Data parameters
|
||||
m = size(X, 1);
|
||||
n = size(X, 2);
|
||||
|
||||
% Map 0 to -1
|
||||
Y(Y==0) = -1;
|
||||
|
||||
% Variables
|
||||
alphas = zeros(m, 1);
|
||||
b = 0;
|
||||
E = zeros(m, 1);
|
||||
passes = 0;
|
||||
eta = 0;
|
||||
L = 0;
|
||||
H = 0;
|
||||
|
||||
% Pre-compute the Kernel Matrix since our dataset is small
|
||||
% (in practice, optimized SVM packages that handle large datasets
|
||||
% gracefully will _not_ do this)
|
||||
%
|
||||
% We have implemented optimized vectorized version of the Kernels here so
|
||||
% that the svm training will run faster.
|
||||
if strcmp(func2str(kernelFunction), 'linearKernel')
|
||||
% Vectorized computation for the Linear Kernel
|
||||
% This is equivalent to computing the kernel on every pair of examples
|
||||
K = X*X';
|
||||
elseif strfind(func2str(kernelFunction), 'gaussianKernel')
|
||||
% Vectorized RBF Kernel
|
||||
% This is equivalent to computing the kernel on every pair of examples
|
||||
X2 = sum(X.^2, 2);
|
||||
K = bsxfun(@plus, X2, bsxfun(@plus, X2', - 2 * (X * X')));
|
||||
K = kernelFunction(1, 0) .^ K;
|
||||
else
|
||||
% Pre-compute the Kernel Matrix
|
||||
% The following can be slow due to the lack of vectorization
|
||||
K = zeros(m);
|
||||
for i = 1:m
|
||||
for j = i:m
|
||||
K(i,j) = kernelFunction(X(i,:)', X(j,:)');
|
||||
K(j,i) = K(i,j); %the matrix is symmetric
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
% Train
|
||||
fprintf('\nTraining ...');
|
||||
dots = 12;
|
||||
while passes < max_passes,
|
||||
|
||||
num_changed_alphas = 0;
|
||||
for i = 1:m,
|
||||
|
||||
% Calculate Ei = f(x(i)) - y(i) using (2).
|
||||
% E(i) = b + sum (X(i, :) * (repmat(alphas.*Y,1,n).*X)') - Y(i);
|
||||
E(i) = b + sum (alphas.*Y.*K(:,i)) - Y(i);
|
||||
|
||||
if ((Y(i)*E(i) < -tol && alphas(i) < C) || (Y(i)*E(i) > tol && alphas(i) > 0)),
|
||||
|
||||
% In practice, there are many heuristics one can use to select
|
||||
% the i and j. In this simplified code, we select them randomly.
|
||||
j = ceil(m * rand());
|
||||
while j == i, % Make sure i \neq j
|
||||
j = ceil(m * rand());
|
||||
end
|
||||
|
||||
% Calculate Ej = f(x(j)) - y(j) using (2).
|
||||
E(j) = b + sum (alphas.*Y.*K(:,j)) - Y(j);
|
||||
|
||||
% Save old alphas
|
||||
alpha_i_old = alphas(i);
|
||||
alpha_j_old = alphas(j);
|
||||
|
||||
% Compute L and H by (10) or (11).
|
||||
if (Y(i) == Y(j)),
|
||||
L = max(0, alphas(j) + alphas(i) - C);
|
||||
H = min(C, alphas(j) + alphas(i));
|
||||
else
|
||||
L = max(0, alphas(j) - alphas(i));
|
||||
H = min(C, C + alphas(j) - alphas(i));
|
||||
end
|
||||
|
||||
if (L == H),
|
||||
% continue to next i.
|
||||
continue;
|
||||
end
|
||||
|
||||
% Compute eta by (14).
|
||||
eta = 2 * K(i,j) - K(i,i) - K(j,j);
|
||||
if (eta >= 0),
|
||||
% continue to next i.
|
||||
continue;
|
||||
end
|
||||
|
||||
% Compute and clip new value for alpha j using (12) and (15).
|
||||
alphas(j) = alphas(j) - (Y(j) * (E(i) - E(j))) / eta;
|
||||
|
||||
% Clip
|
||||
alphas(j) = min (H, alphas(j));
|
||||
alphas(j) = max (L, alphas(j));
|
||||
|
||||
% Check if change in alpha is significant
|
||||
if (abs(alphas(j) - alpha_j_old) < tol),
|
||||
% continue to next i.
|
||||
% replace anyway
|
||||
alphas(j) = alpha_j_old;
|
||||
continue;
|
||||
end
|
||||
|
||||
% Determine value for alpha i using (16).
|
||||
alphas(i) = alphas(i) + Y(i)*Y(j)*(alpha_j_old - alphas(j));
|
||||
|
||||
% Compute b1 and b2 using (17) and (18) respectively.
|
||||
b1 = b - E(i) ...
|
||||
- Y(i) * (alphas(i) - alpha_i_old) * K(i,j)' ...
|
||||
- Y(j) * (alphas(j) - alpha_j_old) * K(i,j)';
|
||||
b2 = b - E(j) ...
|
||||
- Y(i) * (alphas(i) - alpha_i_old) * K(i,j)' ...
|
||||
- Y(j) * (alphas(j) - alpha_j_old) * K(j,j)';
|
||||
|
||||
% Compute b by (19).
|
||||
if (0 < alphas(i) && alphas(i) < C),
|
||||
b = b1;
|
||||
elseif (0 < alphas(j) && alphas(j) < C),
|
||||
b = b2;
|
||||
else
|
||||
b = (b1+b2)/2;
|
||||
end
|
||||
|
||||
num_changed_alphas = num_changed_alphas + 1;
|
||||
|
||||
end
|
||||
|
||||
end
|
||||
|
||||
if (num_changed_alphas == 0),
|
||||
passes = passes + 1;
|
||||
else
|
||||
passes = 0;
|
||||
end
|
||||
|
||||
fprintf('.');
|
||||
dots = dots + 1;
|
||||
if dots > 78
|
||||
dots = 0;
|
||||
fprintf('\n');
|
||||
end
|
||||
if exist('OCTAVE_VERSION')
|
||||
fflush(stdout);
|
||||
end
|
||||
end
|
||||
fprintf(' Done! \n\n');
|
||||
|
||||
% Save the model
|
||||
idx = alphas > 0;
|
||||
model.X= X(idx,:);
|
||||
model.y= Y(idx);
|
||||
model.kernelFunction = kernelFunction;
|
||||
model.b= b;
|
||||
model.alphas= alphas(idx);
|
||||
model.w = ((alphas.*Y)'*X)';
|
||||
|
||||
end
|
||||
Reference in New Issue
Block a user