缺点:
耗内存,存储所有训练样本,对每个测试样本都要计算和所有训练数据的距离,时间成本高
knn 和 Locally weighted linear regression 思想上非常相似,对每个预测点都需要单独训练模型
代码如下:
from numpy import * import os def loaddata(): data=array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]]) labels = ['A', 'A', 'B', 'B'] return data,labels def img2vect(filename): rows=32 cols=32 imgvect=zeros((1,rows*cols)) f=open(filename) for i in range(rows): linstr=f.readline() for j in range(cols): imgvect[0,i*rows+j]=int(linstr[j]) return imgvect def loaddigit(): print('get training sets:') filedir='C:/Users/yourname/Desktop/machin/digit/' trainfilelist=os.listdir(filedir+'trainingDigits') numsamples=len(trainfilelist) train_x=zeros((numsamples,1024)) train_y=[] for i in range(numsamples): filename=trainfilelist[i] train_x[i,:]=img2vect(filedir+'trainingDigits/%s' % filename) label=int(filename.split('_')[0]) train_y.append(label) print( "---Getting testing set..." ) testingFileList = os.listdir(filedir + 'testDigits') numSamples = len(testingFileList) test_x = zeros((numSamples, 1024)) test_y = [] for i in range(numSamples): filename = testingFileList[i] test_x[i, :] = img2vect(filedir + 'testDigits/%s' % filename) label = int(filename.split('_')[0]) test_y.append(label) return train_x, train_y, test_x, test_y def knn(dataset,testdata,labels,k): n=dataset.shape[0] diff=tile(testdata,(n,1))-dataset sdiff=diff**2 sumdist=sum(sdiff,axis=1) dist=sumdist**0.5 dist_sort=argsort(dist) classcount={} for i in range(k): label=labels[dist_sort[i]] classcount[label]=classcount.get(label,0)+1 maxcount=0 for key,value in classcount.items(): if value>maxcount: maxcount=value maxindex=key return maxindex def testHandWritingClass(): train_x, train_y, test_x, test_y=loaddigit() numtestsamples=test_x.shape[0] matchcount=0 for i in range(numtestsamples): predict=knn(train_x,test_x[i],train_y,3) if predict==test_y[i]: matchcount+=1 accuracy=float(matchcount)/numtestsamples print('The classify accuracy is: %.2f%%' % (accuracy * 100)) data,labels=loaddata() testdata=[0.8,0.7] ke=knn(data,testdata,labels,2) print(ke) testHandWritingClass()