说明:
1. 此处是台大林轩田老师主页上的hw7,对应coursera上“机器学习技法”作业三;
2. 本文给出大作业(13-15题)的代码;
3. MATLAB代码;
4. 非职业码农,代码质量不高,变量命名也不规范,凑合着看吧,不好意思;
5. 如有问题,欢迎指教,QQ:50834。
题目13-15的代码分为一个主程序和四个函数:
主程序:
% HW7 problems 13-15: fully grown CART decision tree (Gini impurity, no pruning).
% Reads hw7_train.dat / hw7_test.dat (feature columns followed by one label
% column), grows the tree into the global matrix dtree_node, prints it,
% reports Ein and Eout, and plots true labels (o) against predictions (*).
clear all; close all;
global dtree_node;

data_trn = csvread('hw7_train.dat');
data_tst = csvread('hw7_test.dat');
xtrn = data_trn(:,1:end-1);  ytrn = data_trn(:,end);
xtst = data_tst(:,1:end-1);  ytst = data_tst(:,end);
N    = size(xtrn, 1);
Ntst = size(xtst, 1);
clear data_trn data_tst

% One row per node: [node#, father#, sign/label, dim, threshold, son+#, son-#].
% A row with dim == 0 is a leaf whose class label is in column 3; otherwise
% samples with x(dim) >= threshold go to son+ and the others to son-.
dtree_node = [];
node0 = [1,0,0,0,0,0,0];   % root: node 1, no father
hw7_CART_train(xtrn, ytrn, node0);

% Use size(...,1), not length(): length() returns the larger dimension, which
% is 7 (the column count) whenever the tree has fewer than 7 nodes.
n_nodes = size(dtree_node, 1);

fprintf('=====================================================================\n');
% fprintf(' Node# | Fth.Node | Sign | Deci.Dim | Thrshhld | SonNode+ | SonNode-\n');
% for node = 1:n_nodes
%     fprintf('%4d %9d %8d %8d %13.4f %7d %10d\n', dtree_node(node,:));
% end
% fprintf('---------------------------------------------------------------------\n');
fprintf('Decision tree for classification\n');
for node = 1:n_nodes
    fprintf('%3d ', dtree_node(node,1));
    if dtree_node(node,4) == 0
        fprintf('Leaf node with class = %2d\n', dtree_node(node,3));
    else
        fprintf('if x%d>=%6.4f then node %d else node %d\n', dtree_node(node,4), ...
            dtree_node(node,5), dtree_node(node,6), dtree_node(node,7));
    end
end
fprintf('---------------------------------------------------------------------\n');

% In-sample error and plot (o = true label, * = prediction; blue +1, red -1).
ypred = hw7_CART_pred(dtree_node, xtrn);
fprintf(' Ein = %6.2f%% \n', sum(ypred~=ytrn)/N*100);
fprintf('---------------------------------------------------------------------\n');
figure; hold on;
idxp = ytrn>0; idxm = ytrn<0;
plot(xtrn(idxp,1),xtrn(idxp,2),'bo');
plot(xtrn(idxm,1),xtrn(idxm,2),'ro');
idxp = ypred>0; idxm = ypred<0;
plot(xtrn(idxp,1),xtrn(idxp,2),'b*');
plot(xtrn(idxm,1),xtrn(idxm,2),'r*');

% Out-of-sample error and plot.
ypred = hw7_CART_pred(dtree_node, xtst);
fprintf(' Eout = %6.2f%% \n', sum(ypred~=ytst)/Ntst*100);
fprintf('---------------------------------------------------------------------\n');
figure; hold on;
idxp = ytst>0; idxm = ytst<0;
plot(xtst(idxp,1),xtst(idxp,2),'bo');
plot(xtst(idxm,1),xtst(idxm,2),'ro');
idxp = ypred>0; idxm = ypred<0;
plot(xtst(idxp,1),xtst(idxp,2),'b*');
plot(xtst(idxm,1),xtst(idxm,2),'r*');

% 函数1: (Function 1)
function gini = hw7_gini(y)
% HW7_GINI  Gini impurity of a vector of class labels.
%   gini = hw7_gini(y) returns 1 - sum_c (n_c/N)^2, where n_c counts the
%   entries of y equal to the distinct value c and N = length(y).
%   For an empty y no term is subtracted, so the result is 1.
classes = unique(y);
total = length(y);
gini = 1;
for c = reshape(classes, 1, [])
    share = sum(y == c) / total;
    gini = gini - share^2;
end
end

% 函数2: (Function 2)
function [s, dim, thresh] = hw7_deci_stump_impurity(x, y)
% HW7_DECI_STUMP_IMPURITY  Best axis-aligned threshold split of (x, y) by Gini impurity.
%   [s, dim, thresh] = hw7_deci_stump_impurity(x, y)
%   For every feature column of x and every cut between two consecutive
%   *distinct* sorted values, evaluates the size-weighted impurity
%       (N1*hw7_gini(y_low) + N2*hw7_gini(y_high)) / N
%   and keeps the first cut that is strictly lower than the best value seen
%   so far, starting from hw7_gini(y) of the unsplit data.
%
%   Outputs
%     s      - sign of (fraction of y==1 on the high side) minus (fraction
%              of y==1 on the low side), with 0 mapped to +1.  Counting
%              y==1 presumes labels are +1/-1 -- TODO confirm for other
%              label sets.  If no cut improves, s = sign(sum(y)).
%     dim    - column of x used by the chosen cut (1 if none improves).
%     thresh - midpoint of the two sorted values around the cut, so a test
%              x(:,dim) >= thresh selects the high side; -Inf if no cut
%              improves.
%
%   hw7_gini is re-evaluated on both halves for every cut, so the scan is at
%   least quadratic in N per feature.
[N, k] = size(x);
bx = hw7_gini(y);
thresh = -Inf;
s = sign(sum(y));
dim = 1;
for feat = 1:k
    [xsort, idxsort] = sort(x(:,feat));
    for rec = 1:N-1
        % A cut between equal values cannot be realised by a threshold:
        % the midpoint equals that value, and x >= thresh puts both samples
        % on the high side.  Scoring it would rate a partition that is never
        % actually produced, so skip it.
        if xsort(rec) == xsort(rec+1)
            continue;
        end
        N1 = rec;
        bx1 = hw7_gini(y(idxsort(1:rec)));
        N2 = N - rec;
        bx2 = hw7_gini(y(idxsort(rec+1:end)));
        bx_tmp = (N1*bx1 + N2*bx2) / N;
        if bx_tmp < bx
            bx = bx_tmp;
            thresh = (xsort(rec) + xsort(rec+1)) / 2;
            s = sign(sum(y(idxsort(rec+1:end))==1)/N2 - sum(y(idxsort(1:rec))==1)/N1);
            if s == 0
                s = 1;
            end
            dim = feat;
        end
    end
end
end

% 函数3: (Function 3)
function hw7_CART_train(x, y, node0)
% HW7_CART_TRAIN  Recursively grow a fully grown CART tree into the global dtree_node.
%   hw7_CART_train(x, y, node0) appends node0 as a new row of the global
%   matrix dtree_node and then:
%     * if y holds a single distinct value, makes that row a leaf labelled
%       with it (column 3);
%     * otherwise asks hw7_deci_stump_impurity for (s, dim, thresh), stores
%       them in columns 3-5, and recurses on the rows with
%       x(:,dim) >= thresh (son+, row number stored in column 6) and on the
%       rows with x(:,dim) < thresh (son-, column 7).
%   Row layout: [node#, father#, sign/label, dim, thresh, son+#, son-#].
%   Leaf rows keep the dim column from node0, which is 0 for every child
%   created here.
%
%   node0(1) is used as the row index of the appended row, so it must equal
%   size(dtree_node,1)+1 at call time (e.g. 1 for the root on an empty
%   dtree_node).
%
%   Degenerate inputs stop the recursion instead of looping forever:
%     * empty y: the row stays as given by node0;
%     * a best split that leaves one side empty (e.g. identical x rows
%       with different labels): the row becomes a leaf labelled mode(y).
global dtree_node;
dtree_node = [dtree_node; node0];
fn_num = node0(1);

y_uniq = unique(y);
if numel(y_uniq) == 1
    dtree_node(fn_num,3) = y_uniq;      % pure node -> leaf
    return;
end
if isempty(y)
    return;                             % nothing to split
end

[s, dim, thresh] = hw7_deci_stump_impurity(x, y);
idxp = x(:,dim) >= thresh;
idxm = x(:,dim) < thresh;
if ~any(idxp) || ~any(idxm)
    % The threshold does not separate the samples, so recursing would call
    % this function again on exactly the same data.  Use a majority-vote
    % leaf instead.
    dtree_node(fn_num,3) = mode(y);
    return;
end

dtree_node(fn_num,3) = s;
dtree_node(fn_num,4) = dim;
dtree_node(fn_num,5) = thresh;

% Children are numbered in creation order: son+ subtree first, then son-.
son = size(dtree_node,1) + 1;
dtree_node(fn_num,6) = son;
hw7_CART_train(x(idxp,:), y(idxp), [son, fn_num, 0, 0, 0, 0, 0]);

son = size(dtree_node,1) + 1;
dtree_node(fn_num,7) = son;
hw7_CART_train(x(idxm,:), y(idxm), [son, fn_num, 0, 0, 0, 0, 0]);
end

% 函数4: (Function 4)
function y = hw7_CART_pred(dtree, x)
% HW7_CART_PRED  Classify each row of x with the tree table dtree.
%   y = hw7_CART_pred(dtree, x) returns a column vector with one label per
%   row of x.  Every walk starts at row 1 of dtree.  While the current row's
%   column 4 (split dimension) is > 1e-10, it moves to the row named in
%   column 6 if x(i, dim) >= column 5 (threshold), else to the row named in
%   column 7.  The column 3 value of the row where it stops is the label.
nrows = size(x, 1);
y = zeros(nrows, 1);
for r = 1:nrows
    y(r) = walk_tree(dtree, x(r,:));
end
end

function label = walk_tree(dtree, sample)
% Follow the split rules from row 1 down to a leaf row and return its label.
cur = 1;
while dtree(cur,4) > 1e-10
    d = dtree(cur,4);
    if sample(d) >= dtree(cur,5)
        cur = dtree(cur,6);
    else
        cur = dtree(cur,7);
    end
end
label = dtree(cur,3);
end

% 运行结果: (Run output)
=====================================================================
Decision tree for classification
  1 if x2>=0.6262 then node 2 else node 5
  2 if x1>=0.8782 then node 3 else node 4
  3 Leaf node with class =  1
  4 Leaf node with class = -1
  5 if x1>=0.2244 then node 6 else node 19
  6 if x1>=0.5415 then node 7 else node 12
  7 if x2>=0.2859 then node 8 else node 9
  8 Leaf node with class =  1
  9 if x2>=0.2660 then node 10 else node 11
 10 Leaf node with class = -1
 11 Leaf node with class =  1
 12 if x2>=0.3586 then node 13 else node 16
 13 if x1>=0.2608 then node 14 else node 15
 14 Leaf node with class = -1
 15 Leaf node with class =  1
 16 if x1>=0.5016 then node 17 else node 18
 17 Leaf node with class = -1
 18 Leaf node with class =  1
 19 if x2>=0.1152 then node 20 else node 21
 20 Leaf node with class = -1
 21 Leaf node with class =  1
---------------------------------------------------------------------
 Ein =   0.00% 
---------------------------------------------------------------------
 Eout =  12.60% 
---------------------------------------------------------------------