tensorflow实现knn算法

    xiaoxiao2021-03-25  67

    knn算法介绍

    knn算法是机器学习中最简单的算法。其原理类似于古语“近朱者赤近墨者黑”,即同类物体的差异性小,异类差异性大,而这种差异往往是用“距离”表示。“距离”的度量一般采用欧氏距离。

    算法思路

    1.计算待分类的样本和样本空间中已标记的样本的欧氏距离。(如图中绿点为待分类样本,要计算绿点与图中所有点的距离)

    2.取距离最短的k个点,k个点进行投票,票数最多的类为待测样本的类。(若k为3,则图中实线圆中的点是距离绿点最短的点,其中三角形有两个,正方形1个,所以绿点为三角形;若k为5,则图中虚线中的点为最近邻点,其中正方形有3个,三角形2个,所以绿点为正方形。由此可知,k的取值会影响分类的结果)

    算法的优缺点

    1.优点

    算法简单有效

    2.缺点

    一方面计算量大。当训练集比较大的时候,每一个样本分类都要计算与所有的已标记样本的距离。目前常用的解决方法是事先对已知样本点进行剪辑,事先去除对分类作用不大的样本(例如在样本空间进行划分区域)。另一方面是当已标记样本是不平衡,分类会向占样本多数的类倾斜。解决方案是引进权重。

    tensorflow简单实现knn

    import tensorflow as tf import numpy as np import matplotlib.pyplot as plt from tensorflow.examples.tutorials import mnist mnist_image=mnist.input_data.read_data_sets("C:\\Users\\Administrator\\.keras\\datasets\\",one_hot=True) pixels,real_values=mnist_image.train.next_batch(10) # n=5 # image=pixels[n,:] # image=np.reshape(image, [28,28]) # plt.imshow(image) # plt.show() traindata,trainlabel=mnist_image.train.next_batch(100) testdata,testlabel=mnist_image.test.next_batch(10) traindata_tensor=tf.placeholder('float',[None,784]) testdata_tensor=tf.placeholder('float',[784]) distance=tf.reduce_sum(tf.abs(tf.add(traindata_tensor,tf.neg(testdata_tensor))),reduction_indices=1) pred = tf.arg_min(distance,0) test_num=10 accuracy=0 init=tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) for i in range(test_num): idx=sess.run(pred,feed_dict={traindata_tensor:traindata,testdata_tensor:testdata[i]}) print('test No.%d,the real label %d, the predict label %d'%(i,np.argmax(testlabel[i]),np.argmax(trainlabel[idx]))) if np.argmax(testlabel[i])==np.argmax(trainlabel[idx]): accuracy+=1 print("result:%f"%(1.0*accuracy/test_num))

    输出

    Extracting C:\Users\Administrator\.keras\datasets\train-images-idx3-ubyte.gz Extracting C:\Users\Administrator\.keras\datasets\train-labels-idx1-ubyte.gz Extracting C:\Users\Administrator\.keras\datasets\t10k-images-idx3-ubyte.gz Extracting C:\Users\Administrator\.keras\datasets\t10k-labels-idx1-ubyte.gz test No.0,the real label 7, the predict label 7 test No.1,the real label 2, the predict label 2 test No.2,the real label 1, the predict label 1 test No.3,the real label 0, the predict label 0 test No.4,the real label 4, the predict label 4 test No.5,the real label 1, the predict label 1 test No.6,the real label 4, the predict label 4 test No.7,the real label 9, the predict label 9 test No.8,the real label 5, the predict label 9 test No.9,the real label 9, the predict label 9 result:0.900000
    转载请注明原文地址: https://ju.6miu.com/read-36418.html

    最新回复(0)