%mapminmax为matlab自带的映射函数
[test_wine,pstest]=mapminmax(test_wine′);
%将映射函数的范围参数分别置为0和1
pstest.ymin=0;
pstest.ymax=1;
%对测试集进行[0,1]归一化
[test_wine,pstest]=mapminmax(test_wine,pstest);
%画出原始淀粉光谱图归一化后的图像
figure;
plot(wine,′LineWidth′,2);
title(′原始淀粉光谱图归一化后的图像′,′FontSize′,12);
grid on;
%对训练集和测试集进行转置,以符合libsvm工具箱的数据格式要求
train_wine=train_wine′;
test_wine=test_wine′;
%%选择最佳的SVM参数c&g
%首先进行粗略选择:c&g的变化范围是2^(-10),2^(-9),,2^(10)
[bestacc,bestc,bestg]=SVMcgForClass(train_wine_labels,train_wine,-10,10,-10,10);
%打印粗略选择结果
disp(′打印粗略选择结果′);
str=sprintf(′Best Cross Validation Accuracy=%g%% Best c=%g Best g=%g′,bestacc,bestc,bestg);
disp(str);
%根据粗略选择的结果图再进行精细选择:c的变化范围是2^(-3),2^(-1.5),,2^(10),g的变化范围是2^(-10),2^(-3.5),,2^(4),
[bestacc,bestc,bestg]=SVMcgForClass(train_wine_labels,train_wine,-3,10,-10,4,3,0.5,0.5,0.9);
%打印精细选择结果
disp(′打印精细选择结果′);
str=sprintf(′Best Cross Validation Accuracy=%g%% Best c=%g Best g= %g′,bestacc,bestc,bestg);
disp(str);
%利用最佳的参数进行SVM网络训练
cmd=[′-c′,num2str(bestc),′-g′,num2str(bestg)];
model=svmtrain(train_wine_labels,train_wine,cmd);
%SVM网络预测
[predict_label,accuracy]=svmpredict(test_wine_labels,test_wine,model);
%打印测试集分类准确率
total=length(test_wine_labels);
right=sum(predict_label=test_wine_labels);
disp(′打印测试集分类准确率′);
str=sprintf(′Accuracy=%g%%(%d/%d)′,accuracy(1),right,total);
disp(str);
%%结果分析
%测试集的实际分类和预测分类图
figure;
hold on;
plot(test_wine_labels,′o′);
plot(predict_label,′r*′);
legend(′实际测试集分类′,′预测测试集分类′);
title(′测试集的实际分类和预测分类图′,′FontSize′,10);
%%法2
[train_final,test_final]=scaleForSVM(train_wine,test_wine,0,1);
train_final=train_wine;
test_final=test_wine;
%归一化后可视化
figure;
fori=1:length(train_final(:,1))
plot(train_final(i,1),train_final(i,2),′r*′);
hold on;
end
grid on;
title(′Visualization for 1st dimension & 2nd dimension of scale data′);
%参数c和g寻优选择
%GridSearch Method
[bestCVaccuracy,bestc,bestg]=SVMcgForClass(train_wine_labels,train_final)
cmd=[′-c′,num2str(bestc),′-g′,num2str(bestg)];
%分类预测
model=svmtrain(train_wine_labels,train_final,cmd);
[ptrain_label,train_accuracy]=svmpredict(train_wine_labels,train_final,model);
%train_accuracy
[ptest_label,test_accuracy]=svmpredict(test_wine_labels,test_final,model);
%test_accuracy
%svmplot
figure;
grid on;
svmplot(ptrain_label,train_wine,model);
title(′Train Data Set′);
%%子函数SVMcgForClassm
function[bestacc,bestc,bestg]=SVMcgForClass(train_label,train,cmin,cmax,gmin,gmax,v,cstep,gstep,accstep)
%子函数SVMcgForClass
%输入:
%train_label:训练集标签.要求与libsvm工具箱中要求一致.
%train:训练集.要求与libsvm工具箱中要求一致.
%cmin:惩罚参数c的变化范围的最小值(取以2为底的对数后),即c_min=2^(cmin)默认为-5
%cmax:惩罚参数c的变化范围的最大值(取以2为底的对数后),即c_max=2^(cmax)默认为5
%gmin:参数g的变化范围的最小值(取以2为底的对数后),即g_min=2^(gmin)默认为-5
%gmax:参数g的变化范围的最小值(取以2为底的对数后),即g_min=2^(gmax)默认为5
%v:crossvalidation的参数,即给测试集分为几部分进行crossvalidation默认为3
%cstep:参数c步进的大小默认为1
%gstep:参数g步进的大小默认为1
%accstep:最后显示准确率图时的步进大小默认为1.5
%输出:
%bestacc:Cross Validation过程中的最高分类准确率
% bestc:最佳的参数c
%bestg:最佳的参数g
%about the parameters of SVMcgForClass
ifnargin<10
accstep=1.5;
end
ifnargin<8
accstep=1.5;
cstep=1;
gstep=1;
end
ifnargin<7
accstep=1.5;
v=3;
cstep=1;
gstep=1;
end
ifnargin<6
accstep=1.5;
v=3;
cstep=1;
gstep=1;
gmax=5;
end
ifnargin<5
accstep=1.5;
v=3;
cstep=1;
gstep=1;
gmax=5;
gmin=-5;
end
ifnargin<4
accstep=1.5;
v=3;
cstep=1;
gstep=1;
gmax=5;
gmin=-5;
cmax=5;
end
ifnargin<3
accstep=1.5;
v=3;
cstep=1;
gstep=1;
gmax=5;
gmin=-5;
cmax=5;
cmin=-5;
end
%X:c Y:g cg:accuracy
[X,Y]=meshgrid(cmin:cstep:cmax,gmin:gstep:gmax);
[m,n]=size(X);
cg=zeros(m,n);
%record accuracy with different c & g,and find the best accuracy with the smallest c
bestc=0;
bestg=0;
bestacc=0;
basenum=2;
for i=1:m
for j=1:n
cmd=[′-v′,num2str(v),′-c′,num2str(basenum^X(i,j)),′-g′,num2str(basenum^Y(i,j))];
cg(i,j)=svmtrain(train_label,train,cmd);
ifcg(i,j)>bestacc
bestacc=cg(i,j);
bestc=basenum^X(i,j);
bestg=basenum^Y(i,j);
end
if(cg(i,j)=bestacc&& bestc>basenum^X(i,j))
bestacc=cg(i,j);
bestc=basenum^X(i,j);
bestg=basenum^Y(i,j);
end
end
end
%draw the accuracy with different c & g
figure;
[C,h]=contour(X,Y,cg,60:accstep:100);
clabel(C,h,′FontSize′,10,′Color′,′r′);
xlabel(′log2c′,′FontSize′,10);
ylabel(′log2g′,′FontSize′,10);
title(′参数选择结果图(grid search)′,′FontSize′,10);
grid on;